diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index dc3aa102be78..3017fc96a5e3 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -28,7 +28,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -58,7 +58,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}
diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml
index b1af44736730..1074a5bba6f5 100644
--- a/.github/workflows/build_docker_images.yml
+++ b/.github/workflows/build_docker_images.yml
@@ -28,7 +28,7 @@ jobs:
uses: docker/setup-buildx-action@v1
- name: Check out code
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
- name: Find Changed Dockerfiles
id: file_changes
@@ -99,7 +99,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml
index f47645c1f659..8e8dc92cb57d 100644
--- a/.github/workflows/build_pr_documentation.yml
+++ b/.github/workflows/build_pr_documentation.yml
@@ -17,10 +17,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.10'
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
new file mode 100644
index 000000000000..5ba158b46fde
--- /dev/null
+++ b/.github/workflows/codeql.yml
@@ -0,0 +1,22 @@
+---
+name: CodeQL Security Analysis For Github Actions
+
+on:
+ push:
+ branches: ["main"]
+ workflow_dispatch:
+ # pull_request:
+
+jobs:
+ codeql:
+ name: CodeQL Analysis
+ uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
+ permissions:
+ security-events: write
+ packages: read
+ actions: read
+ contents: read
+ with:
+ languages: '["actions","python"]'
+ queries: 'security-extended,security-and-quality'
+      runner: 'ubuntu-latest' # optional, set if you need a custom runner
diff --git a/.github/workflows/mirror_community_pipeline.yml b/.github/workflows/mirror_community_pipeline.yml
index ab4ded973047..73cced7c1394 100644
--- a/.github/workflows/mirror_community_pipeline.yml
+++ b/.github/workflows/mirror_community_pipeline.yml
@@ -24,7 +24,6 @@ jobs:
mirror_community_pipeline:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
-
runs-on: ubuntu-22.04
steps:
# Checkout to correct ref
@@ -39,37 +38,41 @@ jobs:
# If ref is 'refs/heads/main' => set 'main'
# Else it must be a tag => set {tag}
- name: Set checkout_ref and path_in_repo
+ env:
+ EVENT_NAME: ${{ github.event_name }}
+ EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
+ GITHUB_REF: ${{ github.ref }}
run: |
- if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
- if [ -z "${{ github.event.inputs.ref }}" ]; then
+ if [ "$EVENT_NAME" == "workflow_dispatch" ]; then
+ if [ -z "$EVENT_INPUT_REF" ]; then
echo "Error: Missing ref input"
exit 1
- elif [ "${{ github.event.inputs.ref }}" == "main" ]; then
+ elif [ "$EVENT_INPUT_REF" == "main" ]; then
echo "CHECKOUT_REF=refs/heads/main" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
- echo "CHECKOUT_REF=refs/tags/${{ github.event.inputs.ref }}" >> $GITHUB_ENV
- echo "PATH_IN_REPO=${{ github.event.inputs.ref }}" >> $GITHUB_ENV
+ echo "CHECKOUT_REF=refs/tags/$EVENT_INPUT_REF" >> $GITHUB_ENV
+ echo "PATH_IN_REPO=$EVENT_INPUT_REF" >> $GITHUB_ENV
fi
- elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
- echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
+ elif [ "$GITHUB_REF" == "refs/heads/main" ]; then
+ echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
# e.g. refs/tags/v0.28.1 -> v0.28.1
- echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
- echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
+ echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
+ echo "PATH_IN_REPO=$(echo $GITHUB_REF | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
fi
- name: Print env vars
run: |
echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
with:
ref: ${{ env.CHECKOUT_REF }}
# Setup + install dependencies
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
@@ -99,4 +102,4 @@ jobs:
- name: Report failure status
if: ${{ failure() }}
run: |
- pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
\ No newline at end of file
+ pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 8b7e57e91297..416d2af3fc2e 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -28,7 +28,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
- name: Pipeline Tests Artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -97,7 +97,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -119,7 +119,7 @@ jobs:
module: [models, schedulers, lora, others, single_file, examples]
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -167,7 +167,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_${{ matrix.module }}_cuda_test_reports
path: reports
@@ -184,7 +184,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_compile_test_reports
path: reports
@@ -228,7 +228,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -263,7 +263,7 @@ jobs:
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_cuda_big_gpu_test_reports
path: reports
@@ -280,7 +280,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -321,7 +321,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -355,7 +355,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -391,7 +391,7 @@ jobs:
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_cuda_${{ matrix.config.backend }}_reports
path: reports
@@ -408,7 +408,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -441,7 +441,7 @@ jobs:
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
@@ -466,7 +466,7 @@ jobs:
image: diffusers/diffusers-pytorch-cpu
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -474,7 +474,7 @@ jobs:
run: mkdir -p combined_reports
- name: Download all test reports
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v7
with:
path: artifacts
@@ -500,7 +500,7 @@ jobs:
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
- name: Upload consolidated report
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: consolidated_test_report
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
@@ -514,7 +514,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
-# uses: actions/checkout@v3
+# uses: actions/checkout@v6
# with:
# fetch-depth: 2
#
@@ -554,7 +554,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
-# uses: actions/upload-artifact@v4
+# uses: actions/upload-artifact@v6
# with:
# name: torch_mps_test_reports
# path: reports
@@ -570,7 +570,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
-# uses: actions/checkout@v3
+# uses: actions/checkout@v6
# with:
# fetch-depth: 2
#
@@ -610,7 +610,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
-# uses: actions/upload-artifact@v4
+# uses: actions/upload-artifact@v6
# with:
# name: torch_mps_test_reports
# path: reports
diff --git a/.github/workflows/notify_slack_about_release.yml b/.github/workflows/notify_slack_about_release.yml
index 612ad4e24503..6c0b96954e81 100644
--- a/.github/workflows/notify_slack_about_release.yml
+++ b/.github/workflows/notify_slack_about_release.yml
@@ -10,10 +10,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: '3.8'
diff --git a/.github/workflows/pr_dependency_test.yml b/.github/workflows/pr_dependency_test.yml
index b914d1076190..dfc35c41066f 100644
--- a/.github/workflows/pr_dependency_test.yml
+++ b/.github/workflows/pr_dependency_test.yml
@@ -18,9 +18,9 @@ jobs:
check_dependencies:
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml
index 13c228621f5c..3bdfb4ca99c6 100644
--- a/.github/workflows/pr_modular_tests.yml
+++ b/.github/workflows/pr_modular_tests.yml
@@ -1,3 +1,4 @@
+
name: Fast PR tests for Modular
on:
@@ -35,9 +36,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
@@ -55,9 +56,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
@@ -77,23 +78,13 @@ jobs:
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
- strategy:
- fail-fast: false
- matrix:
- config:
- - name: Fast PyTorch Modular Pipeline CPU tests
- framework: pytorch_pipelines
- runner: aws-highmemory-32-plus
- image: diffusers/diffusers-pytorch-cpu
- report: torch_cpu_modular_pipelines
-
- name: ${{ matrix.config.name }}
+ name: Fast PyTorch Modular Pipeline CPU tests
runs-on:
- group: ${{ matrix.config.runner }}
+ group: aws-highmemory-32-plus
container:
- image: ${{ matrix.config.image }}
+ image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
@@ -102,7 +93,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -118,22 +109,19 @@ jobs:
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
- if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-k "not Flax and not Onnx" \
- --make-reports=tests_${{ matrix.config.report }} \
+ --make-reports=tests_torch_cpu_modular_pipelines \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
- run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+ run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
- name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
+ name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
path: reports
-
-
diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml
index 83b2ab4edbf6..a02a40709fcc 100644
--- a/.github/workflows/pr_test_fetcher.yml
+++ b/.github/workflows/pr_test_fetcher.yml
@@ -28,7 +28,7 @@ jobs:
test_map: ${{ steps.set_matrix.outputs.test_map }}
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Install dependencies
@@ -42,7 +42,7 @@ jobs:
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
- name: Report fetched tests
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v6
with:
name: test_fetched
path: test_preparation.txt
@@ -83,7 +83,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -109,7 +109,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v6
with:
name: ${{ matrix.modules }}_test_reports
path: reports
@@ -138,7 +138,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -164,7 +164,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index 674e62ff443a..c0dfa89e776d 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -31,9 +31,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -51,9 +51,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -108,7 +108,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -153,7 +153,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
@@ -185,7 +185,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
@@ -236,7 +236,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -273,7 +273,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pr_main_test_reports
path: reports
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index 468979d379c1..dd20bbe93250 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -32,9 +32,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -52,9 +52,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -83,7 +83,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -100,7 +100,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -120,7 +120,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -170,7 +170,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -193,7 +193,7 @@ jobs:
module: [models, schedulers, lora, others]
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -239,7 +239,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -255,7 +255,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -287,7 +287,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: examples_test_reports
path: reports
diff --git a/.github/workflows/pr_torch_dependency_test.yml b/.github/workflows/pr_torch_dependency_test.yml
index 4b6160ff71e2..4a7e5ab37e47 100644
--- a/.github/workflows/pr_torch_dependency_test.yml
+++ b/.github/workflows/pr_torch_dependency_test.yml
@@ -18,9 +18,9 @@ jobs:
check_torch_dependencies:
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index 7b1c441d3dc0..4456f18c95bc 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -29,7 +29,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -46,7 +46,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -66,7 +66,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -98,7 +98,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -120,7 +120,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -155,7 +155,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -172,7 +172,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -199,7 +199,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_compile_test_reports
path: reports
@@ -216,7 +216,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -240,7 +240,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_xformers_test_reports
path: reports
@@ -256,7 +256,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -286,7 +286,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: examples_test_reports
path: reports
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index 38cbffaa6315..fe6f6a265e89 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -54,7 +54,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -88,7 +88,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml
index 2d6feb592815..cc16d5f82cd0 100644
--- a/.github/workflows/push_tests_mps.yml
+++ b/.github/workflows/push_tests_mps.yml
@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -65,7 +65,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pr_torch_mps_test_reports
path: reports
diff --git a/.github/workflows/pypi_publish.yaml b/.github/workflows/pypi_publish.yaml
index dc36b6b024c5..214b996f5381 100644
--- a/.github/workflows/pypi_publish.yaml
+++ b/.github/workflows/pypi_publish.yaml
@@ -15,10 +15,10 @@ jobs:
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
steps:
- name: Checkout Repo
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: '3.8'
@@ -40,12 +40,12 @@ jobs:
steps:
- name: Checkout Repo
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
- name: Setup Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
python-version: "3.8"
diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml
index efdd6ea2b651..f667d715090d 100644
--- a/.github/workflows/release_tests_fast.yml
+++ b/.github/workflows/release_tests_fast.yml
@@ -27,7 +27,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -94,7 +94,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -116,7 +116,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -149,7 +149,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_cuda_${{ matrix.module }}_test_reports
path: reports
@@ -166,7 +166,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -205,7 +205,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -222,7 +222,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -247,7 +247,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_compile_test_reports
path: reports
@@ -264,7 +264,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -288,7 +288,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: torch_xformers_test_reports
path: reports
@@ -305,7 +305,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -336,7 +336,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: examples_test_reports
path: reports
diff --git a/.github/workflows/run_tests_from_a_pr.yml b/.github/workflows/run_tests_from_a_pr.yml
index fa8c579dd768..3e5462f5100f 100644
--- a/.github/workflows/run_tests_from_a_pr.yml
+++ b/.github/workflows/run_tests_from_a_pr.yml
@@ -57,7 +57,7 @@ jobs:
shell: bash -e {0}
- name: Checkout PR branch
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
ref: refs/pull/${{ inputs.pr_number }}/head
diff --git a/.github/workflows/ssh-pr-runner.yml b/.github/workflows/ssh-pr-runner.yml
index 49fa9c0ad24d..27246fb61348 100644
--- a/.github/workflows/ssh-pr-runner.yml
+++ b/.github/workflows/ssh-pr-runner.yml
@@ -27,7 +27,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
diff --git a/.github/workflows/ssh-runner.yml b/.github/workflows/ssh-runner.yml
index 917eb5b1b31a..4fbfad3dc7c6 100644
--- a/.github/workflows/ssh-runner.yml
+++ b/.github/workflows/ssh-runner.yml
@@ -35,7 +35,7 @@ jobs:
steps:
- name: Checkout diffusers
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
with:
fetch-depth: 2
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 27450ed4c7f2..b0a90e278550 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -15,10 +15,10 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v6
- name: Setup Python
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v6
with:
python-version: 3.8
diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml
index 4743dc352455..65334e086c83 100644
--- a/.github/workflows/trufflehog.yml
+++ b/.github/workflows/trufflehog.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Secret Scanning
diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml
index 6d2f2fc8dd9a..64bac653537d 100644
--- a/.github/workflows/typos.yml
+++ b/.github/workflows/typos.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: typos-action
uses: crate-ci/typos@v1.12.4
diff --git a/.github/workflows/update_metadata.yml b/.github/workflows/update_metadata.yml
index 92aea0369ba8..6e608883c13a 100644
--- a/.github/workflows/update_metadata.yml
+++ b/.github/workflows/update_metadata.yml
@@ -15,7 +15,7 @@ jobs:
shell: bash -l {0}
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v6
- name: Setup environment
run: |
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
deleted file mode 100644
index ec18df882641..000000000000
--- a/CONTRIBUTING.md
+++ /dev/null
@@ -1,506 +0,0 @@
-
-
-# How to contribute to Diffusers 🧨
-
-We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation –not just code– are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it!
-
-Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕.
-
-Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility.
-
-We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered.
-
-## Overview
-
-You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to
-the core library.
-
-In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
-
-* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
-* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose).
-* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues).
-* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
-* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
-* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
-* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples).
-* 8. Fix a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22).
-* 9. Add a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md).
-
-As said before, **all contributions are valuable to the community**.
-In the following, we will explain each type of contribution in a bit more detail.
-
-For all contributions 4-9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr).
-
-### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
-
-Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to):
-- Reports of training or inference experiments in an attempt to share knowledge
-- Presentation of personal projects
-- Questions about non-official training examples
-- Project proposals
-- General feedback
-- Paper summaries
-- Asking for help on personal projects that build on top of the Diffusers library
-- General questions
-- Ethical questions regarding diffusion models
-- ...
-
-Every question that is asked on the forum or on Discord actively encourages the community to publicly
-share knowledge and might very well help a beginner in the future who has the same question you're
-having. Please do pose any questions you might have.
-In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.
-
-**Please** keep in mind that the more effort you put into asking or answering a question, the higher
-the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
-In short, a high-quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
-
-**NOTE about channels**:
-[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
-In addition, questions and answers posted in the forum can easily be linked to.
-In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication.
-While it will most likely take less time for you to get an answer to your question on Discord, your
-question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers.
-
-### 2. Opening new issues on the GitHub issues tab
-
-The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of
-the problems they encounter. So thank you for reporting an issue.
-
-Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design.
-
-In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
-
-**Please consider the following guidelines when opening a new issue**:
-- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues).
-- Please never report a new issue on another (related) issue. If another issue is highly related, please
-open a new issue nevertheless and link to the related issue.
-- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English.
-- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that the version reported by `python -c "import diffusers; print(diffusers.__version__)"` matches or is higher than the latest Diffusers version.
-- Remember that the more effort you put into opening a new issue, the higher the quality of the answers you will receive and the better the overall quality of the Diffusers issues.
-
-New issues usually include the following.
-
-#### 2.1. Reproducible, minimal bug reports
-
-A bug report should always have a reproducible code snippet and be as minimal and concise as possible.
-This means in more detail:
-- Narrow the bug down as much as you can; **do not just dump your whole code file**.
-- Format your code.
-- Do not include any external libraries except for those that Diffusers depends on.
-- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
-- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it.
-- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
-- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
-
-For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
-
-You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml).
-
-#### 2.2. Feature requests
-
-A world-class feature request addresses the following points:
-
-1. Motivation first:
-* Is it related to a problem/frustration with the library? If so, please explain
-why. Providing a code snippet that demonstrates the problem is best.
-* Is it related to something you would need for a project? We'd love to hear
-about it!
-* Is it something you worked on and think could benefit the community?
-Awesome! Tell us what problem it solved for you.
-2. Write a *full paragraph* describing the feature;
-3. Provide a **code snippet** that demonstrates its future use;
-4. In case this is related to a paper, please attach a link;
-5. Attach any additional information (drawings, screenshots, etc.) you think may help.
-
-You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).
-
-#### 2.3 Feedback
-
-Feedback about the library design and why it is or isn't good helps the core maintainers immensely to build a user-friendly library. To understand the philosophy behind the current design, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too rigidly, hence restricting use cases, explain why and how it should be changed.
-If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.
-
-You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
-
-#### 2.4 Technical questions
-
-Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on
-why this part of the code is difficult to understand.
-
-You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml).
-
-#### 2.5 Proposal to add a new model, scheduler, or pipeline
-
-If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information:
-
-* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release.
-* Link to any of its open-source implementations.
-* Link to the model weights if they are available.
-
-If you are willing to contribute the model yourself, let us know so we can best guide you. Also, don't forget
-to tag the original author of the component (model, scheduler, pipeline, etc.) by their GitHub handle if you can find it.
-
-You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml).
-
-### 3. Answering issues on the GitHub issues tab
-
-Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct.
-Some tips to give a high-quality answer to an issue:
-- Be as concise and minimal as possible.
-- Stay on topic. An answer to the issue should concern the issue and only the issue.
-- Provide links to code, papers, or other sources that prove or encourage your point.
-- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet.
-
-Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great
-help to the maintainers if you can answer such issues, encouraging the author of the issue to be
-more precise, provide the link to a duplicated issue or redirect them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
-
-If you have verified that the reported bug is correct and requires a correction in the source code,
-please have a look at the next sections.
-
-For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
-
-### 4. Fixing a "Good first issue"
-
-*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already
-explains how a potential solution should look so that it is easier to fix.
-If the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message "I would like to try this issue.". There are usually three scenarios:
-- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it.
-- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR.
-- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR.
-
-
-### 5. Contribute to the documentation
-
-A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly
-valuable contribution**.
-
-Contributing to the documentation can take many forms:
-
-- Correcting spelling or grammatical errors.
-- Correcting incorrect formatting of docstrings. If you see that the official documentation is weirdly displayed or a link is broken, we are very happy if you take some time to correct it.
-- Correcting the shape or dimensions of a docstring input or output tensor.
-- Clarifying documentation that is hard to understand or incorrect.
-- Updating outdated code examples.
-- Translating the documentation to another language.
-
-Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected or adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source).
-
-Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally.
-
-
-### 6. Contribute a community pipeline
-
-[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user.
-Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview).
-We support two types of pipelines:
-
-- Official Pipelines
-- Community Pipelines
-
-Both official and community pipelines follow the same design and consist of the same type of components.
-
-Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code
-resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
-In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested.
-They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution.
-
-The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all
-possible ways diffusion models can be used for inference, but some of them may be of interest to the community.
-Officially released diffusion pipelines,
-such as Stable Diffusion, are added to the core src/diffusers/pipelines package, which ensures
-high-quality maintenance, no backward-breaking code changes, and testing.
-More bleeding edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library.
-
-To add a community pipeline, one should add a .py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline.
-
-An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400).
-
-Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors.
-
-Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the
-core package.
-
-### 7. Contribute to training examples
-
-Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples).
-
-We support two types of training examples:
-
-- Official training examples
-- Research training examples
-
-Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders.
-The official training examples are maintained by the Diffusers core maintainers, whereas the research training examples are maintained by the community.
-This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
-If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
-
-Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
-training examples, it is required to clone the repository:
-
-```bash
-git clone https://github.com/huggingface/diffusers
-```
-
-as well as to install all additional dependencies required for training:
-
-```bash
-cd diffusers
-pip install -r examples/<your-example-folder>/requirements.txt
-```
-
-Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
-
-Training examples of the Diffusers library should adhere to the following philosophy:
-- All the code necessary to run the examples should be found in a single Python file.
-- One should be able to run the example from the command line with `python <your-example>.py --args`.
-- Examples should be kept simple and serve as **an example** on how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials.
-
-To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of what they should look like.
-We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated
-with Diffusers.
-Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include:
-- An example command on how to run the example script as shown [here e.g.](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch).
-- A link to some training results (logs, models, ...) that show what the user can expect as shown [here e.g.](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
-- If you are adding a non-official/research training example, **please don't forget** to add a sentence stating that you are maintaining this training example, including your git handle, as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
-
-If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples.
-
-### 8. Fixing a "Good second issue"
-
-*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are
-usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
-The issue description usually gives less guidance on how to fix the issue and requires
-a decent understanding of the library by the interested contributor.
-If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR.
-Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished, the core maintainers can also jump into your PR and commit to it in order to get it merged.
-
-### 9. Adding pipelines, models, schedulers
-
-Pipelines, models, and schedulers are the most important pieces of the Diffusers library.
-They provide easy access to state-of-the-art diffusion technologies and thus allow the community to
-build powerful generative AI applications.
-
-By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem.
-
-Diffusers has a couple of open feature requests for all three components - feel free to look them over
-if you don't yet know which specific component you would like to add:
-- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
-- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
-
-Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) a read to better understand the design of any of the three components. Please be aware that
-we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
-as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
-open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
-pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
-
-Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
-original author directly on the PR so that they can follow the progress and potentially help with questions.
-
-If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
-
-## How to write a good issue
-
-**The better your issue is written, the higher the chances that it will be quickly resolved.**
-
-1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose).
-2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simple as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers".
-3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data.
-4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets.
-5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better.
-6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information.
-7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library.
-
-## How to write a good PR
-
-1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged.
-2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once.
-3. If helpful, try to add a code snippet that displays an example of how your addition can be used.
-4. The title of your pull request should be a summary of its contribution.
-5. If your pull request addresses an issue, please mention the issue number in
-the pull request description to make sure they are linked (and people
-consulting the issue know you are working on it);
-6. To indicate a work in progress please prefix the title with `[WIP]`. These
-are useful to avoid duplicated work, and to differentiate it from PRs ready
-to be merged;
-7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue).
-8. Make sure existing tests pass;
-9. Add high-coverage tests. No quality testing = no merge.
-- If you are adding new `@slow` tests, make sure they pass using
-`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
-CircleCI does not run the slow tests, but GitHub Actions does every night!
-10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example.
-11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
-[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files.
-If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
-to this dataset.
-
-## How to open a PR
-
-Before writing code, we strongly advise you to search through the existing PRs or
-issues to make sure that nobody is already working on the same thing. If you are
-unsure, it is always a good idea to open an issue to get some feedback.
-
-You will need basic `git` proficiency to be able to contribute to
-🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest
-manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
-Git](https://git-scm.com/book/en/v2) is a very good reference.
-
-Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)):
-
-1. Fork the [repository](https://github.com/huggingface/diffusers) by
-clicking on the 'Fork' button on the repository's page. This creates a copy of the code
-under your GitHub user account.
-
-2. Clone your fork to your local disk, and add the base repository as a remote:
-
- ```bash
- $ git clone git@github.com:/diffusers.git
- $ cd diffusers
- $ git remote add upstream https://github.com/huggingface/diffusers.git
- ```
-
-3. Create a new branch to hold your development changes:
-
- ```bash
- $ git checkout -b a-descriptive-name-for-my-changes
- ```
-
-**Do not** work on the `main` branch.
-
-4. Set up a development environment by running the following command in a virtual environment:
-
- ```bash
- $ pip install -e ".[dev]"
- ```
-
-If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the
-library.
-
-5. Develop the features on your branch.
-
-As you work on the features, you should make sure that the test suite
-passes. You should run the tests impacted by your changes like this:
-
- ```bash
- $ pytest tests/.py
- ```
-
-Before you run the tests, please make sure you install the dependencies required for testing. You can do so
-with this command:
-
- ```bash
- $ pip install -e ".[test]"
- ```
-
-You can also run the full test suite with the following command, but it takes
-a beefy machine to produce a result in a decent amount of time now that
-Diffusers has grown a lot. Here is the command for it:
-
- ```bash
- $ make test
- ```
-
-🧨 Diffusers relies on `ruff` and `isort` to format its source code
-consistently. After you make changes, apply automatic style corrections and code verifications
-that can't be automated in one go with:
-
- ```bash
- $ make style
- ```
-
-🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
-control runs in CI, however, you can also run the same checks with:
-
- ```bash
- $ make quality
- ```
-
-Once you're happy with your changes, add changed files using `git add` and
-make a commit with `git commit` to record your changes locally:
-
- ```bash
- $ git add modified_file.py
- $ git commit -m "A descriptive message about your changes."
- ```
-
-It is a good idea to sync your copy of the code with the original
-repository regularly. This way you can quickly account for changes:
-
- ```bash
- $ git pull upstream main
- ```
-
-Push the changes to your account using:
-
- ```bash
- $ git push -u origin a-descriptive-name-for-my-changes
- ```
-
-6. Once you are satisfied, go to the
-webpage of your fork on GitHub. Click on 'Pull request' to send your changes
-to the project maintainers for review.
-
-7. It's ok if maintainers ask you for changes. It happens to core contributors
-too! So everyone can see the changes in the Pull request, work in your local
-branch and push the changes to your fork. They will automatically appear in
-the pull request.
-
-### Tests
-
-An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
-the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests).
-
-We like `pytest` and `pytest-xdist` because it's faster. From the root of the
-repository, here's how to run tests with `pytest` for the library:
-
-```bash
-$ python -m pytest -n auto --dist=loadfile -s -v ./tests/
-```
-
-In fact, that's how `make test` is implemented!
-
-You can specify a smaller set of tests in order to test only the feature
-you're working on.
-
-By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to
-`yes` to run them. This will download many gigabytes of models — make sure you
-have enough disk space and a good Internet connection, or a lot of patience!
-
-```bash
-$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
-```
-
-`unittest` is fully supported, here's how to run tests with it:
-
-```bash
-$ python -m unittest discover -s tests -t . -v
-$ python -m unittest discover -s examples -t examples -v
-```
-
-### Syncing forked main with upstream (HuggingFace) main
-
-To avoid pinging the upstream repository which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs,
-when syncing the main branch of a forked repository, please, follow these steps:
-1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main.
-2. If a PR is absolutely necessary, use the following steps after checking out your branch:
-```bash
-$ git checkout -b your-branch-for-syncing
-$ git pull --squash --no-commit upstream main
-$ git commit -m ''
-$ git push --set-upstream origin your-branch-for-syncing
-```
-
-### Style guide
-
-For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 120000
index 000000000000..53de38ca21e3
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1 @@
+docs/source/en/conceptual/contribution.md
\ No newline at end of file
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index f0cb0164436e..46e241d817b5 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -54,6 +54,8 @@
title: Batch inference
- local: training/distributed_inference
title: Distributed inference
+ - local: hybrid_inference/overview
+ title: Remote inference
title: Inference
- isExpanded: false
sections:
@@ -88,17 +90,6 @@
title: FreeU
title: Community optimizations
title: Inference optimization
-- isExpanded: false
- sections:
- - local: hybrid_inference/overview
- title: Overview
- - local: hybrid_inference/vae_decode
- title: VAE Decode
- - local: hybrid_inference/vae_encode
- title: VAE Encode
- - local: hybrid_inference/api_reference
- title: API Reference
- title: Hybrid Inference
- isExpanded: false
sections:
- local: modular_diffusers/overview
@@ -270,6 +261,8 @@
title: Outputs
- local: api/quantization
title: Quantization
+ - local: hybrid_inference/api_reference
+ title: Remote inference
- local: api/parallel
title: Parallel inference
title: Main Classes
@@ -353,6 +346,8 @@
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
+ - local: api/models/glm_image_transformer2d
+ title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -367,6 +362,8 @@
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
+ - local: api/models/ltx2_video_transformer3d
+ title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -443,6 +440,10 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
+ - local: api/models/autoencoderkl_audio_ltx_2
+ title: AutoencoderKLLTX2Audio
+ - local: api/models/autoencoderkl_ltx_2
+ title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -495,6 +496,8 @@
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
+ - local: api/pipelines/bria_fibo_edit
+ title: Bria Fibo Edit
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -541,6 +544,8 @@
title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
+ - local: api/pipelines/glm_image
+ title: GLM-Image
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
@@ -678,6 +683,8 @@
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
+ - local: api/pipelines/ltx2
+ title: LTX-2
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi
diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md
index c93dcad43821..6a2d74892cfa 100644
--- a/docs/source/en/api/cache.md
+++ b/docs/source/en/api/cache.md
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] apply_faster_cache
-### FirstBlockCacheConfig
+## FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 7911bc2b2332..bbae6a9020af 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -33,6 +33,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
+- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
> [!TIP]
@@ -62,6 +63,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
+## LTX2LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin
+
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
diff --git a/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md
new file mode 100644
index 000000000000..d0024474e9e0
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md
@@ -0,0 +1,29 @@
+
+
+# AutoencoderKLLTX2Audio
+
+The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. It is used to encode and decode audio latent representations.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import AutoencoderKLLTX2Audio
+
+vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLLTX2Audio
+
+[[autodoc]] AutoencoderKLLTX2Audio
+ - encode
+ - decode
+ - all
\ No newline at end of file
diff --git a/docs/source/en/api/models/autoencoderkl_ltx_2.md b/docs/source/en/api/models/autoencoderkl_ltx_2.md
new file mode 100644
index 000000000000..1dbf516c017a
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_ltx_2.md
@@ -0,0 +1,29 @@
+
+
+# AutoencoderKLLTX2Video
+
+The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import AutoencoderKLLTX2Video
+
+vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLLTX2Video
+
+[[autodoc]] AutoencoderKLLTX2Video
+ - decode
+ - encode
+ - all
diff --git a/docs/source/en/api/models/controlnet_flux.md b/docs/source/en/api/models/controlnet_flux.md
index 6b230d90fba3..ec0370c19e06 100644
--- a/docs/source/en/api/models/controlnet_flux.md
+++ b/docs/source/en/api/models/controlnet_flux.md
@@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co
## FluxControlNetOutput
-[[autodoc]] models.controlnet_flux.FluxControlNetOutput
\ No newline at end of file
+[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput
\ No newline at end of file
diff --git a/docs/source/en/api/models/controlnet_sparsectrl.md b/docs/source/en/api/models/controlnet_sparsectrl.md
index b9e81dc57eeb..0aa9848d0d2b 100644
--- a/docs/source/en/api/models/controlnet_sparsectrl.md
+++ b/docs/source/en/api/models/controlnet_sparsectrl.md
@@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr
## SparseControlNetOutput
-[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput
+[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput
diff --git a/docs/source/en/api/models/glm_image_transformer2d.md b/docs/source/en/api/models/glm_image_transformer2d.md
new file mode 100644
index 000000000000..7a18d1050075
--- /dev/null
+++ b/docs/source/en/api/models/glm_image_transformer2d.md
@@ -0,0 +1,18 @@
+
+
+# GlmImageTransformer2DModel
+
+A Diffusion Transformer model for 2D data used in [GLM-Image](https://huggingface.co/zai-org/GLM-Image).
+
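+The model can be loaded with a snippet along the following lines (a minimal sketch: the `subfolder="transformer"` layout is an assumption borrowed from the sibling model docs, so check the checkpoint layout on the Hub).
+
+```python
+import torch
+
+from diffusers import GlmImageTransformer2DModel
+
+# Hypothetical subfolder; verify against the zai-org/GLM-Image repository structure.
+transformer = GlmImageTransformer2DModel.from_pretrained(
+    "zai-org/GLM-Image", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+```
+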
+## GlmImageTransformer2DModel
+
+[[autodoc]] GlmImageTransformer2DModel
diff --git a/docs/source/en/api/models/ltx2_video_transformer3d.md b/docs/source/en/api/models/ltx2_video_transformer3d.md
new file mode 100644
index 000000000000..9faab8695468
--- /dev/null
+++ b/docs/source/en/api/models/ltx2_video_transformer3d.md
@@ -0,0 +1,26 @@
+
+
+# LTX2VideoTransformer3DModel
+
+A Diffusion Transformer model for 3D data from [LTX-2](https://huggingface.co/Lightricks/LTX-2), introduced by Lightricks.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import LTX2VideoTransformer3DModel
+
+transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## LTX2VideoTransformer3DModel
+
+[[autodoc]] LTX2VideoTransformer3DModel
diff --git a/docs/source/en/api/pipelines/bria_fibo_edit.md b/docs/source/en/api/pipelines/bria_fibo_edit.md
new file mode 100644
index 000000000000..b46dd78cdb90
--- /dev/null
+++ b/docs/source/en/api/pipelines/bria_fibo_edit.md
@@ -0,0 +1,33 @@
+
+
+# Bria Fibo Edit
+
+Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows.
+Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
+Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality.
+
+## Usage
+_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Edit Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form, and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._
+
+Use the command below to log in:
+
+```bash
+hf auth login
+```
+
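+Once you have access, the pipeline can be loaded and called roughly as follows (a minimal sketch: the `prompt` and `image` argument names, the dtype, and the plain-text edit instruction are assumptions; the model's structured JSON input format and exact call signature should be checked against the pipeline reference below).
+
+```py
+import torch
+
+from diffusers import BriaFiboEditPipeline
+from diffusers.utils import load_image
+
+pipe = BriaFiboEditPipeline.from_pretrained("briaai/Fibo-Edit", torch_dtype=torch.bfloat16).to("cuda")
+
+# Hypothetical inputs: a source image plus an edit instruction.
+source = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
+image = pipe(prompt="Replace the background with a starry night sky.", image=source).images[0]
+image.save("fibo_edit.png")
+```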
+
+## BriaFiboEditPipeline
+
+[[autodoc]] BriaFiboEditPipeline
+ - all
+ - __call__
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index cc52ffa09a6d..2b3b50c25e80 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -99,3 +99,9 @@ image.save("chroma-single-file.png")
[[autodoc]] ChromaImg2ImgPipeline
- all
- __call__
+
+## ChromaInpaintPipeline
+
+[[autodoc]] ChromaInpaintPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/chronoedit.md b/docs/source/en/api/pipelines/chronoedit.md
index 48e70ab9e55e..5e7057f9ccb8 100644
--- a/docs/source/en/api/pipelines/chronoedit.md
+++ b/docs/source/en/api/pipelines/chronoedit.md
@@ -30,6 +30,10 @@
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
+Available Models/LoRAs:
+- [nvidia/ChronoEdit-14B-Diffusers](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers)
+- [nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora)
+- [nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora)
### Image Editing
@@ -100,6 +104,7 @@ Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.pn
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
@@ -109,9 +114,8 @@ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encod
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
-lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
-pipe.load_lora_weights(lora_path)
-pipe.fuse_lora(lora_scale=1.0)
+pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
+pipe.fuse_lora(adapter_names=["distill"], lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")
@@ -145,6 +149,57 @@ export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
+### Inference with Multiple LoRAs
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+from diffusers.schedulers import UniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+from PIL import Image
+
+model_id = "nvidia/ChronoEdit-14B-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
+pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora", weight_name="paintbrush_lora_diffusers.safetensors", adapter_name="paintbrush")
+pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
+pipe.fuse_lora(adapter_names=["paintbrush", "distill"], lora_scale=1.0)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
+pipe.to("cuda")
+
+image = load_image(
+ "https://raw.githubusercontent.com/nv-tlabs/ChronoEdit/refs/heads/main/assets/images/input_paintbrush.png"
+)
+max_area = 720 * 1280
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+print("width", width, "height", height)
+image = image.resize((width, height))
+prompt = (
+ "Turn the pencil sketch in the image into an actual object that is consistent with the image’s content. The user wants to change the sketch to a crown and a hat."
+)
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=5,
+ num_inference_steps=8,
+ guidance_scale=1.0,
+ enable_temporal_reasoning=False,
+ num_temporal_reasoning_steps=0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output_1.png")
+```
+
## ChronoEditPipeline
[[autodoc]] ChronoEditPipeline
diff --git a/docs/source/en/api/pipelines/diffedit.md b/docs/source/en/api/pipelines/diffedit.md
index 9734ca2eabc3..670b7bb4fca0 100644
--- a/docs/source/en/api/pipelines/diffedit.md
+++ b/docs/source/en/api/pipelines/diffedit.md
@@ -21,7 +21,7 @@ The abstract from the paper is:
*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*
-The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
+The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).
This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️
diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md
index 393e0d03c341..4ace2f3b3aa0 100644
--- a/docs/source/en/api/pipelines/flux2.md
+++ b/docs/source/en/api/pipelines/flux2.md
@@ -35,5 +35,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
## Flux2Pipeline
[[autodoc]] Flux2Pipeline
+ - all
+ - __call__
+
+## Flux2KleinPipeline
+
+[[autodoc]] Flux2KleinPipeline
- all
- __call__
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md
new file mode 100644
index 000000000000..a99832787847
--- /dev/null
+++ b/docs/source/en/api/pipelines/glm_image.md
@@ -0,0 +1,95 @@
+
+
+# GLM-Image
+
+## Overview
+
+GLM-Image is an image generation model that adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained detail. In general image generation quality, it matches industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.
+
+Model architecture: a hybrid autoregressive + diffusion decoder design:
+
++ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands it to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. The AR model is available as the `GlmImageForConditionalGeneration` class in the `transformers` library.
++ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images.
+
+Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.
+
++ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
++ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.
+
+GLM-Image supports both text-to-image and image-to-image generation within a single model:
+
++ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
++ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
+
+This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image).
+
+## Usage examples
+
+### Text to Image Generation
+
+```python
+import torch
+from diffusers.pipelines.glm_image import GlmImagePipeline
+
+pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="cuda")
+prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy."
+image = pipe(
+ prompt=prompt,
+ height=32 * 32,
+ width=36 * 32,
+ num_inference_steps=30,
+ guidance_scale=1.5,
+ generator=torch.Generator(device="cuda").manual_seed(42),
+).images[0]
+
+image.save("output_t2i.png")
+```
+
+### Image to Image Generation
+
+```python
+import torch
+from diffusers.pipelines.glm_image import GlmImagePipeline
+from PIL import Image
+
+pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="cuda")
+image_path = "cond.jpg"
+prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator."
+image = Image.open(image_path).convert("RGB")
+image = pipe(
+ prompt=prompt,
+    image=[image],  # pass multiple images for multi-image-to-image generation, e.g. [image, image1]
+ height=33 * 32,
+ width=32 * 32,
+ num_inference_steps=30,
+ guidance_scale=1.5,
+ generator=torch.Generator(device="cuda").manual_seed(42),
+).images[0]
+
+image.save("output_i2i.png")
+```
+
++ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting `do_sample=False`, as this may lead to incorrect or degenerate outputs from the AR model.
+
+## GlmImagePipeline
+
+[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline
+ - all
+ - __call__
+
+## GlmImagePipelineOutput
+
+[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput
diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md
new file mode 100644
index 000000000000..4c6860daf024
--- /dev/null
+++ b/docs/source/en/api/pipelines/ltx2.md
@@ -0,0 +1,47 @@
+
+
+# LTX-2
+
+
+
+
+
+LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
+
+You can find all the original LTX-2 checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
+
+The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
+
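+A minimal text-to-video sketch is shown below. It assumes [`LTX2Pipeline`] follows the usual `from_pretrained` and prompt-based calling pattern of the other video pipelines in the library; the argument names and the `frames` output field are assumptions, so verify them against the pipeline reference below.
+
+```py
+import torch
+
+from diffusers import LTX2Pipeline
+from diffusers.utils import export_to_video
+
+pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")
+
+# Assumed call signature and output layout; check the API reference below.
+output = pipe(prompt="A red fox trotting through fresh snow at sunrise", num_inference_steps=30)
+export_to_video(output.frames[0], "ltx2.mp4", fps=24)
+```
+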
+## LTX2Pipeline
+
+[[autodoc]] LTX2Pipeline
+ - all
+ - __call__
+
+## LTX2ImageToVideoPipeline
+
+[[autodoc]] LTX2ImageToVideoPipeline
+ - all
+ - __call__
+
+## LTX2LatentUpsamplePipeline
+
+[[autodoc]] LTX2LatentUpsamplePipeline
+ - all
+ - __call__
+
+## LTX2PipelineOutput
+
+[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput
diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
index 940144538a35..68658f41dabc 100644
--- a/docs/source/en/api/pipelines/ltx_video.md
+++ b/docs/source/en/api/pipelines/ltx_video.md
@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
+ - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.
- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)
Show example code
-
+
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
@@ -474,6 +474,12 @@ export_to_video(video, "output.mp4", fps=24)
+## LTXI2VLongMultiPromptPipeline
+
+[[autodoc]] LTXI2VLongMultiPromptPipeline
+ - all
+ - __call__
+
## LTXPipeline
[[autodoc]] LTXPipeline
diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md
index b3dd3dd93618..ee3dd3b28e4d 100644
--- a/docs/source/en/api/pipelines/qwenimage.md
+++ b/docs/source/en/api/pipelines/qwenimage.md
@@ -95,7 +95,7 @@ image.save("qwen_fewsteps.png")
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
-```
+```py
import torch
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
- image=[image_1, image_2],
- prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
+ image=[image_1, image_2],
+ prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
num_inference_steps=50
).images[0]
```
+## Performance
+
+### torch.compile
+
+Using `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):
+
+```python
+import torch
+from diffusers import QwenImagePipeline
+
+pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
+pipe.transformer = torch.compile(pipe.transformer)
+
+# First call triggers compilation (~7s overhead)
+# Subsequent calls run ~2.4x faster
+image = pipe("a cat", num_inference_steps=50).images[0]
+```
+
+### Batched Inference with Variable-Length Prompts
+
+When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.
+
+```python
+# CFG with different prompt lengths works correctly
+image = pipe(
+ prompt="A cat",
+ negative_prompt="blurry, low quality, distorted",
+ true_cfg_scale=3.5,
+ num_inference_steps=50,
+).images[0]
+```
+
+For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).
+
## QwenImagePipeline
[[autodoc]] QwenImagePipeline
diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md
index 6730f1551607..e1829bc409eb 100644
--- a/docs/source/en/api/pipelines/skyreels_v2.md
+++ b/docs/source/en/api/pipelines/skyreels_v2.md
@@ -37,7 +37,8 @@ The following SkyReels-V2 models are supported in Diffusers:
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
-- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
+
+This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
> [!TIP]
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index 6aab6c5b33b9..d5fdbbfe0f95 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
-
-
-
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md
index 865aaba5ebb6..b538cb350481 100644
--- a/docs/source/en/hybrid_inference/api_reference.md
+++ b/docs/source/en/hybrid_inference/api_reference.md
@@ -1,9 +1,11 @@
-# Hybrid Inference API Reference
+# Remote inference
-## Remote Decode
+Remote inference uses an [Inference Endpoint](https://huggingface.co/docs/inference-endpoints/index) to offload VAE decoding and encoding, reducing the memory required for local generation.
+
+## remote_decode
[[autodoc]] utils.remote_utils.remote_decode
-## Remote Encode
+## remote_encode
[[autodoc]] utils.remote_utils.remote_encode
diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md
index 7ed1bbb88b3f..1384be9b7348 100644
--- a/docs/source/en/hybrid_inference/overview.md
+++ b/docs/source/en/hybrid_inference/overview.md
@@ -10,51 +10,296 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Hybrid Inference
+# Remote inference
-**Empowering local AI builders with Hybrid Inference**
+> [!TIP]
+> This is currently an experimental feature, and if you have any feedback, please feel free to leave it [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+Remote inference offloads the decoding and encoding process to a remote endpoint to relax the memory requirements for local inference with large models. This feature is powered by [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index). Refer to the table below for the supported models and endpoint.
-> [!TIP]
-> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
-> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+| Model | Endpoint | Checkpoint | Support |
+|---|---|---|---|
+| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) | encode/decode |
+| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) | encode/decode |
+| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | encode/decode |
+| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | decode |
+
+This guide will show you how to encode and decode latents with remote inference.
+
+## Encoding
+
+Encoding converts images and videos into latent representations. Refer to the table above for the supported VAEs.
+
+Pass an image to [`~utils.remote_encode`] to encode it. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
+
+```py
+import torch
+from diffusers import FluxPipeline
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_encode
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ torch_dtype=torch.float16,
+ vae=None,
+ device_map="cuda"
+)
+
+init_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+)
+init_image = init_image.resize((768, 512))
+
+init_latent = remote_encode(
+ endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
+ image=init_image,
+ scaling_factor=0.3611,
+ shift_factor=0.1159
+)
+```
+
+## Decoding
+
+Decoding converts latent representations back into images or videos. Refer to the table above for the supported VAEs.
+
+Set the output type to `"latent"` in the pipeline and set the `vae` to `None`. Pass the latents to the [`~utils.remote_decode`] function. For Flux, the latents are packed so the `height` and `width` also need to be passed. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
+
+
+
+
+```py
+import torch
+
+from diffusers import FluxPipeline
+from diffusers.utils.remote_utils import remote_decode
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ torch_dtype=torch.bfloat16,
+ vae=None,
+ device_map="cuda"
+)
+
+prompt = """
+A photorealistic Apollo-era photograph of a cat in a small astronaut suit with a bubble helmet, standing on the Moon and holding a flagpole planted in the dusty lunar soil. The flag shows a colorful paw-print emblem. Earth glows in the black sky above the stark gray surface, with sharp shadows and high-contrast lighting like vintage NASA photos.
+"""
+
+latent = pipeline(
+ prompt=prompt,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ output_type="latent",
+).images
+image = remote_decode(
+ endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ height=1024,
+ width=1024,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+image.save("image.jpg")
+```
+
+
+
+
+```py
+import torch
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils.remote_utils import remote_decode
+
+model_id = "hunyuanvideo-community/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+)
+pipeline = HunyuanVideoPipeline.from_pretrained(
+    model_id, transformer=transformer, vae=None, torch_dtype=torch.float16, device_map="cuda"
+)
+
+latent = pipeline(
+ prompt="A cat walks on the grass, realistic",
+ height=320,
+ width=512,
+ num_frames=61,
+ num_inference_steps=30,
+ output_type="latent",
+).frames
+
+video = remote_decode(
+ endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ output_type="mp4",
+)
+
+if isinstance(video, bytes):
+ with open("video.mp4", "wb") as f:
+ f.write(video)
+```
+
+
+
+
+## Queuing
+
+Remote inference supports queuing to process multiple generation requests. While the current latent is being decoded, you can queue the next prompt.
+
+```py
+import queue
+import threading
+
+import torch
+from IPython.display import display
+
+from diffusers import StableDiffusionXLPipeline
+from diffusers.utils.remote_utils import remote_decode
+
+def decode_worker(q: queue.Queue):
+ while True:
+ item = q.get()
+ if item is None:
+ break
+ image = remote_decode(
+ endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=item,
+ scaling_factor=0.13025,
+ )
+ display(image)
+ q.task_done()
+
+q = queue.Queue()
+thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
+thread.start()
+
+def decode(latent: torch.Tensor):
+ q.put(latent)
+
+prompts = [
+ "A grainy Apollo-era style photograph of a cat in a snug astronaut suit with a bubble helmet, standing on the lunar surface and gripping a flag with a paw-print emblem. The gray Moon landscape stretches behind it, Earth glowing vividly in the black sky, shadows crisp and high-contrast.",
+ "A vintage 1960s sci-fi pulp magazine cover illustration of a heroic cat astronaut planting a flag on the Moon. Bold, saturated colors, exaggerated space gear, playful typography floating in the background, Earth painted in bright blues and greens.",
+ "A hyper-detailed cinematic shot of a cat astronaut on the Moon holding a fluttering flag, fur visible through the helmet glass, lunar dust scattering under its feet. The vastness of space and Earth in the distance create an epic, awe-inspiring tone.",
+ "A colorful cartoon drawing of a happy cat wearing a chunky, oversized spacesuit, proudly holding a flag with a big paw print on it. The Moon’s surface is simplified with craters drawn like doodles, and Earth in the sky has a smiling face.",
+ "A monochrome 1969-style press photo of a “first cat on the Moon” moment. The cat, in a tiny astronaut suit, stands by a planted flag, with grainy textures, scratches, and a blurred Earth in the background, mimicking old archival space photos."
+]
+
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ vae=None,
+ device_map="cuda"
+)
+
+pipeline.unet = pipeline.unet.to(memory_format=torch.channels_last)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+
+_ = pipeline(
+ prompt=prompts[0],
+ output_type="latent",
+)
+
+for prompt in prompts:
+ latent = pipeline(
+ prompt=prompt,
+ output_type="latent",
+ ).images
+ decode(latent)
+
+q.put(None)
+thread.join()
+```
+
+## Benchmarks
+
+The tables below show the time and memory requirements for encoding and decoding with Stable Diffusion v1.5 and SDXL on different GPUs.
+For the majority of these GPUs, the memory usage dictates whether other models (text encoders, UNet/transformer) need to be offloaded or whether tiled encoding is required. Both techniques increase inference time and can impact quality.
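+
+The offloading and tiling mentioned above correspond to the following calls, shown as a minimal local-only sketch (the SDXL checkpoint here is just an illustrative choice, not part of the remote inference flow):
+
+```py
+import torch
+
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+)
+
+# Offload idle components to CPU so only the active model sits on the GPU.
+pipeline.enable_model_cpu_offload()
+
+# Encode/decode the VAE in tiles to cap peak memory, at the cost of extra time.
+pipeline.vae.enable_tiling()
+```
+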
+Encoding - Stable Diffusion v1.5
-## Why use Hybrid Inference?
+| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
+|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
-Hybrid Inference offers a fast and simple way to offload local generation requirements.
+
-- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
-- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
-- 💰 **Cost Effective:** It's free! 🤑
-- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
-- 🔧 **Developer-Friendly:** Simple requests, fast responses.
+Encoding SDXL
----
+| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
+|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
-## Available Models
+
-* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
-* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
-* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
+Decoding - Stable Diffusion v1.5
----
+| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
+| --- | --- | --- | --- | --- | --- |
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
+| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
+| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
+| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
+| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
-## Integrations
+
-* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
-* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
+Decoding - SDXL
-## Changelog
+| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
+| --- | --- | --- | --- | --- | --- |
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
+| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
+| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
+| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
+| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
-- March 10 2025: Added VAE encode
-- March 2 2025: Initial release with VAE decoding
+
-## Contents
-The documentation is organized into three sections:
+## Resources
-* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
-* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
-* **API Reference** Dive into task-specific settings and parameters.
+- Remote inference is also supported in [SD.Next](https://github.com/vladmandic/sdnext) and [ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae).
+- Refer to the [Remote VAEs for decoding with Inference Endpoints](https://huggingface.co/blog/remote_vae) blog post to learn more.
\ No newline at end of file
diff --git a/docs/source/en/hybrid_inference/vae_decode.md b/docs/source/en/hybrid_inference/vae_decode.md
deleted file mode 100644
index 1457090550c7..000000000000
--- a/docs/source/en/hybrid_inference/vae_decode.md
+++ /dev/null
@@ -1,345 +0,0 @@
-# Getting Started: VAE Decode with Hybrid Inference
-
-VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
-
-## Memory
-
-These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
-
-For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality.
-
-SD v1.5
-
-| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
-| --- | --- | --- | --- | --- | --- |
-| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
-| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
-| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
-| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
-| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
-| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
-| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
-| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
-| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
-| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
-| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
-| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
-
-
-
-SDXL
-
-| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
-| --- | --- | --- | --- | --- | --- |
-| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
-| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
-| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
-| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
-| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
-| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
-| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
-| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
-| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
-| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
-| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
-| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
-
-
-
-## Available VAEs
-
-| | **Endpoint** | **Model** |
-|:-:|:-----------:|:--------:|
-| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
-| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
-| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
-| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
-
-
-> [!TIP]
-> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
-
-
-## Code
-
-> [!TIP]
-> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
-
-
-A helper method simplifies interacting with Hybrid Inference.
-
-```python
-from diffusers.utils.remote_utils import remote_decode
-```
-
-### Basic example
-
-Here, we show how to use the remote VAE on random tensors.
-
-Code
-
-```python
-image = remote_decode(
- endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
- scaling_factor=0.18215,
-)
-```
-
-
-
-
-
-
-
-Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
-
-Code
-
-```python
-image = remote_decode(
- endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
- height=1024,
- width=1024,
- scaling_factor=0.3611,
- shift_factor=0.1159,
-)
-```
-
-
-
-
-
-
-
-Finally, an example for HunyuanVideo.
-
-Code
-
-```python
-video = remote_decode(
- endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
- output_type="mp4",
-)
-with open("video.mp4", "wb") as f:
- f.write(video)
-```
-
-
-
-
-
-
-
-
-### Generation
-
-But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
-
-Code
-
-```python
-from diffusers import StableDiffusionPipeline
-
-pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- variant="fp16",
- vae=None,
-).to("cuda")
-
-prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
-
-latent = pipe(
- prompt=prompt,
- output_type="latent",
-).images
-image = remote_decode(
- endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=latent,
- scaling_factor=0.18215,
-)
-image.save("test.jpg")
-```
-
-
-
-
-
-
-
-Here’s another example with Flux.
-
-Code
-
-```python
-from diffusers import FluxPipeline
-
-pipe = FluxPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-schnell",
- torch_dtype=torch.bfloat16,
- vae=None,
-).to("cuda")
-
-prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
-
-latent = pipe(
- prompt=prompt,
- guidance_scale=0.0,
- num_inference_steps=4,
- output_type="latent",
-).images
-image = remote_decode(
- endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=latent,
- height=1024,
- width=1024,
- scaling_factor=0.3611,
- shift_factor=0.1159,
-)
-image.save("test.jpg")
-```
-
-
-
-
-
-
-
-Here’s an example with HunyuanVideo.
-
-Code
-
-```python
-from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
-
-model_id = "hunyuanvideo-community/HunyuanVideo"
-transformer = HunyuanVideoTransformer3DModel.from_pretrained(
- model_id, subfolder="transformer", torch_dtype=torch.bfloat16
-)
-pipe = HunyuanVideoPipeline.from_pretrained(
- model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
-).to("cuda")
-
-latent = pipe(
- prompt="A cat walks on the grass, realistic",
- height=320,
- width=512,
- num_frames=61,
- num_inference_steps=30,
- output_type="latent",
-).frames
-
-video = remote_decode(
- endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=latent,
- output_type="mp4",
-)
-
-if isinstance(video, bytes):
- with open("video.mp4", "wb") as f:
- f.write(video)
-```
-
-
-
-
-
-
-
-
-### Queueing
-
-One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
-
-
-Code
-
-```python
-import queue
-import threading
-from IPython.display import display
-from diffusers import StableDiffusionPipeline
-
-def decode_worker(q: queue.Queue):
- while True:
- item = q.get()
- if item is None:
- break
- image = remote_decode(
- endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=item,
- scaling_factor=0.18215,
- )
- display(image)
- q.task_done()
-
-q = queue.Queue()
-thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
-thread.start()
-
-def decode(latent: torch.Tensor):
- q.put(latent)
-
-prompts = [
- "Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
- "Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
- "Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
- "Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
- "A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
- "Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
-]
-
-pipe = StableDiffusionPipeline.from_pretrained(
- "Lykon/dreamshaper-8",
- torch_dtype=torch.float16,
- vae=None,
-).to("cuda")
-
-pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
-_ = pipe(
- prompt=prompts[0],
- output_type="latent",
-)
-
-for prompt in prompts:
- latent = pipe(
- prompt=prompt,
- output_type="latent",
- ).images
- decode(latent)
-
-q.put(None)
-thread.join()
-```
-
-
-
-
-
-
-
-
-## Integrations
-
-* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
-* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md
deleted file mode 100644
index dd285fa25c03..000000000000
--- a/docs/source/en/hybrid_inference/vae_encode.md
+++ /dev/null
@@ -1,183 +0,0 @@
-# Getting Started: VAE Encode with Hybrid Inference
-
-VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
-
-## Memory
-
-These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
-
-For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
-
-SD v1.5
-
-| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
-|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
-| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
-| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
-| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
-| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
-| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
-| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
-| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
-| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
-| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
-| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
-| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
-| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
-| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
-| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
-| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
-| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
-| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
-| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
-| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
-| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
-
-
-
-
-SDXL
-
-| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
-|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
-| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
-| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
-| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
-| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
-| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
-| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
-| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
-| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
-| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
-| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
-| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
-| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
-| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
-| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
-| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
-| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
-| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
-| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
-| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
-| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
-
-
-
-## Available VAEs
-
-| | **Endpoint** | **Model** |
-|:-:|:-----------:|:--------:|
-| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
-| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
-| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
-
-
-> [!TIP]
-> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
-
-
-## Code
-
-> [!TIP]
-> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
-
-
-A helper method simplifies interacting with Hybrid Inference.
-
-```python
-from diffusers.utils.remote_utils import remote_encode
-```
-
-### Basic example
-
-Let's encode an image, then decode it to demonstrate.
-
-
-
-
-
-Code
-
-```python
-from diffusers.utils import load_image
-from diffusers.utils.remote_utils import remote_decode
-
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
-
-latent = remote_encode(
- endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
- scaling_factor=0.3611,
- shift_factor=0.1159,
-)
-
-decoded = remote_decode(
- endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=latent,
- scaling_factor=0.3611,
- shift_factor=0.1159,
-)
-```
-
-
-
-
-
-
-
-
-### Generation
-
-Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
-
-Code
-
-```python
-import torch
-from diffusers import StableDiffusionImg2ImgPipeline
-from diffusers.utils import load_image
-from diffusers.utils.remote_utils import remote_decode, remote_encode
-
-pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- variant="fp16",
- vae=None,
-).to("cuda")
-
-init_image = load_image(
- "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
-)
-init_image = init_image.resize((768, 512))
-
-init_latent = remote_encode(
- endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
- image=init_image,
- scaling_factor=0.18215,
-)
-
-prompt = "A fantasy landscape, trending on artstation"
-latent = pipe(
- prompt=prompt,
- image=init_latent,
- strength=0.75,
- output_type="latent",
-).images
-
-image = remote_decode(
- endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
- tensor=latent,
- scaling_factor=0.18215,
-)
-image.save("fantasy_landscape.jpg")
-```
-
-
-
-
-
-
-
-## Integrations
-
-* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
-* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
diff --git a/docs/source/en/modular_diffusers/custom_blocks.md b/docs/source/en/modular_diffusers/custom_blocks.md
index 1c311582264e..6ef8db613f7f 100644
--- a/docs/source/en/modular_diffusers/custom_blocks.md
+++ b/docs/source/en/modular_diffusers/custom_blocks.md
@@ -140,7 +140,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
type_hint=str,
required=True,
default="mask_image",
- description="""Output type from annotation predictions. Availabe options are
+ description="""Output type from annotation predictions. Available options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
@@ -256,7 +256,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
type_hint=str,
required=True,
default="mask_image",
- description="""Output type from annotation predictions. Availabe options are
+ description="""Output type from annotation predictions. Available options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
index a80309de19a6..74a868922799 100644
--- a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
+++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
@@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to
A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.
-- It recieves the iteration variable from the loop wrapper.
+- It receives the iteration variable from the loop wrapper.
- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].
diff --git a/docs/source/en/optimization/cache.md b/docs/source/en/optimization/cache.md
index 6397c7d4cd2e..3854ecd469f8 100644
--- a/docs/source/en/optimization/cache.md
+++ b/docs/source/en/optimization/cache.md
@@ -68,6 +68,20 @@ config = FasterCacheConfig(
pipeline.transformer.enable_cache(config)
```
+## FirstBlockCache
+
+[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser change from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
+)
+apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
+```
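+
+Generation then runs as usual; the hook reuses the cached output whenever the first block's change between steps stays below `threshold`. A minimal usage sketch (device, prompt, and step count are illustrative):
+
+```py
+pipeline.to("cuda")
+
+image = pipeline(
+    "a cup of coffee on a wooden table, soft morning light",
+    num_inference_steps=50,
+).images[0]
+image.save("qwen_image_fbc.png")
+```
+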
## TaylorSeer Cache
[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
@@ -87,8 +101,7 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
+).to("cuda")
config = TaylorSeerCacheConfig(
cache_interval=5,
@@ -97,4 +110,4 @@ config = TaylorSeerCacheConfig(
taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
-```
\ No newline at end of file
+```
diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
index 18cc109e0785..a501c4efc43a 100644
--- a/docs/source/en/quantization/torchao.md
+++ b/docs/source/en/quantization/torchao.md
@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
- quantzation_config=pipeline_quant_config,
+ quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
- quantzation_config=pipeline_quant_config,
+ quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
- quantzation_config=pipeline_quant_config,
+ quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md
index 22e8a30427b9..e7ec1480aabd 100644
--- a/docs/source/en/training/distributed_inference.md
+++ b/docs/source/en/training/distributed_inference.md
@@ -263,8 +263,8 @@ def main():
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
- )
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ ).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)
@@ -314,6 +314,35 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```
+### Unified Attention
+
+[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
+
+This hybrid approach leverages the strengths of both methods:
+- **Ulysses Attention** efficiently parallelizes across attention heads
+- **Ring Attention** handles very long sequences with minimal memory overhead
+- Together, they enable 2D parallelization across both heads and sequence dimensions
+
+[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
+Pass a [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to values greater than 1 to [`~ModelMixin.enable_parallelism`].
+
+```py
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
+```
+
+> [!TIP]
+> Use Unified Attention when there are enough devices to arrange them in a 2D grid (at least 4 devices).
+
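+For example, on a single node with 4 GPUs, a 2x2 grid uses every device (`ulysses_degree * ring_degree == world_size`). A minimal sketch, assuming the same distributed setup (process group and device placement) as the examples above:
+
+```py
+import torch.distributed as dist
+
+world_size = dist.get_world_size()  # 4 GPUs in this example
+ulysses_degree = 2
+ring_degree = world_size // ulysses_degree  # 2
+
+pipeline.transformer.enable_parallelism(
+    config=ContextParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree)
+)
+```
+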
+We benchmarked Ulysses, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node with 4 H100 GPUs. The results are summarized as follows:
+
+| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
+|--------------------|------------------|-------------|------------------|
+| ulysses | 6670.789 | 7.50 | 33.85 |
+| ring | 13076.492 | 3.82 | 56.02 |
+| unified_balanced | 11068.705 | 4.52 | 33.85 |
+
+The table shows that Ulysses provides the best throughput, but the number of devices it can use is limited by the number of attention heads; Unified Attention removes this limitation while matching Ulysses' peak memory.
+
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index cf0a1588f39b..8fb749d328c9 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1929,6 +1929,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
if args.cache_latents:
latents_cache = []
+ # Store vae config before potential deletion
+ vae_scaling_factor = vae.config.scaling_factor
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
@@ -1940,6 +1942,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ else:
+ vae_scaling_factor = vae.config.scaling_factor
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
@@ -2109,13 +2113,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = vae.encode(pixel_values).latent_dist.sample()
if latents_mean is None and latents_std is None:
- model_input = model_input * vae.config.scaling_factor
+ model_input = model_input * vae_scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
else:
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
- model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+ model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
model_input = model_input.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index 113d9b57398e..001934298abe 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -149,13 +149,13 @@ def get_args():
"--validation_prompt",
type=str,
default=None,
- help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+ help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
)
parser.add_argument(
"--validation_images",
type=str,
default=None,
- help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
+ help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
)
parser.add_argument(
"--validation_prompt_separator",
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index bcafe4ecf5d9..f6f2dc83a3f9 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -140,7 +140,7 @@ def get_args():
"--validation_prompt",
type=str,
default=None,
- help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+ help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
)
parser.add_argument(
"--validation_prompt_separator",
diff --git a/examples/community/pipeline_z_image_differential_img2img.py b/examples/community/pipeline_z_image_differential_img2img.py
new file mode 100644
index 000000000000..8bde065c4013
--- /dev/null
+++ b/examples/community/pipeline_z_image_differential_img2img.py
@@ -0,0 +1,844 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import AutoTokenizer, PreTrainedModel
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import ZImageTransformer2DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from pipeline_z_image_differential_img2img import ZImageDifferentialImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = ZImageDifferentialImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+        >>> init_image = load_image(
+        ...     "https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true",
+        ... )
+
+        >>> mask = load_image(
+        ...     "https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true",
+        ... )
+
+ >>> prompt = "painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art"
+
+ >>> image = pipe(
+ ... prompt,
+ ... image=init_image,
+ ... mask_image=mask,
+ ... strength=0.75,
+ ... num_inference_steps=9,
+ ... guidance_scale=0.0,
+ ... generator=torch.Generator("cuda").manual_seed(41),
+ ... ).images[0]
+ >>> image.save("image.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class ZImageDifferentialImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+    The Z-Image pipeline for differential image-to-image generation, where a mask (change map) controls how strongly each region of the input image is modified.
+
+ Args:
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`PreTrainedModel`]):
+ A text encoder model to encode text prompts.
+ tokenizer ([`AutoTokenizer`]):
+ A tokenizer to tokenize text prompts.
+ transformer ([`ZImageTransformer2DModel`]):
+ A ZImage transformer model to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: PreTrainedModel,
+ tokenizer: AutoTokenizer,
+ transformer: ZImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
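+        # Image preprocessing rounds to multiples of vae_scale_factor * 2 (VAE downsampling plus the transformer's 2x2 latent patching).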
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
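+        # The mask (change map) is kept as a single-channel image in [0, 1]: no normalization or binarization, converted to grayscale.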
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor,
+ vae_latent_channels=latent_channels,
+ do_normalize=False,
+ do_binarize=False,
+ do_convert_grayscale=True,
+ )
+
+ # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if do_classifier_free_guidance:
+ if negative_prompt is None:
+ negative_prompt = ["" for _ in prompt]
+ else:
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ assert len(prompt) == len(negative_prompt)
+ negative_prompt_embeds = self._encode_prompt(
+ prompt=negative_prompt,
+ device=device,
+ prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ )
+ else:
+ negative_prompt_embeds = []
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ max_sequence_length: int = 512,
+ ) -> List[torch.FloatTensor]:
+ device = device or self._execution_device
+
+ if prompt_embeds is not None:
+ return prompt_embeds
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ for i, prompt_item in enumerate(prompt):
+ messages = [
+ {"role": "user", "content": prompt_item},
+ ]
+ prompt_item = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True,
+ )
+ prompt[i] = prompt_item
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device)
+ prompt_masks = text_inputs.attention_mask.to(device).bool()
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_masks,
+ output_hidden_states=True,
+ ).hidden_states[-2]
+
+ embeddings_list = []
+
+ for i in range(len(prompt_embeds)):
+ embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
+
+ return embeddings_list
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
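+        # Build per-patch position ids: a (height//2 * width//2, 3) tensor where index 1 holds the row and index 2 the column of each 2x2 latent patch.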
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
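+        # Convert pixel dimensions to latent dimensions (VAE downsampling), rounded down to a multiple of 2 for the 2x2 patching.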
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ # Encode the input image
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != num_channels_latents:
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ else:
+ image_latents = image
+
+ # Handle batch size expansion
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+
+ # Add noise using flow matching scale_noise
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+
+ return latents, noise, image_latents, latent_image_ids
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
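+        # Scale the mask-image latents the same way as in prepare_latents (the inverse of VAE decoding's latents / scaling_factor + shift_factor).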
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+        # align device to prevent errors when concatenating with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ strength: float = 0.6,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ cfg_normalization: bool = False,
+ cfg_truncation: float = 1.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for image-to-image generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
+ list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+ a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. Black pixels in the mask
+ are repainted while white pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+            height (`int`, *optional*):
+                The height in pixels of the generated image. If not provided, the height of the input image is used.
+            width (`int`, *optional*):
+                The width in pixels of the generated image. If not provided, the width of the input image is used.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ cfg_normalization (`bool`, *optional*, defaults to False):
+                Whether to apply normalization to the classifier-free guidance output.
+ cfg_truncation (`float`, *optional*, defaults to 1.0):
+                The truncation value for classifier-free guidance.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
+ tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
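+ A minimal, illustrative call (the pipeline class name and the checkpoint id below are placeholders,
+ not part of this file; substitute the class actually exported here and a real Z-Image checkpoint):
+
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+
+ >>> # hypothetical names for illustration only
+ >>> pipe = ZImageDiffDiffPipeline.from_pretrained("<z-image-checkpoint>", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> image = load_image("input.png")
+ >>> mask = load_image("mask.png")
+ >>> result = pipe(
+ ...     prompt="a photo of a cat",
+ ...     image=image,
+ ...     mask_image=mask,
+ ...     strength=0.6,
+ ...     num_inference_steps=50,
+ ...     guidance_scale=5.0,
+ ... ).images[0]
+ ```
+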
+ Returns:
+ [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ # 1. Check inputs and validate strength
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # Get dimensions from the preprocessed image if not specified
+ if height is None:
+ height = init_image.shape[-2]
+ if width is None:
+ width = init_image.shape[-1]
+
+ vae_scale = self.vae_scale_factor * 2
+ if height % vae_scale != 0:
+ raise ValueError(
+ f"Height must be divisible by {vae_scale} (got {height}). "
+ f"Please adjust the height to a multiple of {vae_scale}."
+ )
+ if width % vae_scale != 0:
+ raise ValueError(
+ f"Width must be divisible by {vae_scale} (got {width}). "
+ f"Please adjust the width to a multiple of {vae_scale}."
+ )
+
+ device = self._execution_device
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+ self._cfg_normalization = cfg_normalization
+ self._cfg_truncation = cfg_truncation
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = len(prompt_embeds)
+
+ # If prompt_embeds is provided and prompt is None, skip encoding
+ if prompt_embeds is not None and prompt is None:
+ if self.do_classifier_free_guidance and negative_prompt_embeds is None:
+ raise ValueError(
+ "When `prompt_embeds` is provided without `prompt`, "
+ "`negative_prompt_embeds` must also be provided for classifier-free guidance."
+ )
+ else:
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.in_channels
+
+ # Repeat prompt_embeds for num_images_per_prompt
+ if num_images_per_prompt > 1:
+ prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+ if self.do_classifier_free_guidance and negative_prompt_embeds:
+ negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+ actual_batch_size = batch_size * num_images_per_prompt
+
+ # Calculate latent dimensions for image_seq_len
+ latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ image_seq_len = (latent_height // 2) * (latent_width // 2)
+
+ # 5. Prepare timesteps
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ self.scheduler.sigma_min = 0.0
+ scheduler_kwargs = {"mu": mu}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+
+ # 6. Adjust timesteps based on strength
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(actual_batch_size)
+
+ # 7. Prepare latents from image
+ latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ actual_batch_size,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds[0].dtype,
+ device,
+ generator,
+ latents,
+ )
+ resize_mode = "default"
+ crops_coords = None
+
+ # start diff diff preparation
+ original_mask = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ masked_image = init_image * original_mask
+ original_mask, _ = self.prepare_mask_latents(
+ original_mask,
+ masked_image,
+ batch_size,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds[0].dtype,
+ device,
+ generator,
+ )
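+ # Build one binary mask per denoising step by thresholding the soft mask at i / num_inference_steps.
+ # Regions with larger mask values keep the (re-noised) original image latents for more steps, while
+ # regions with smaller mask values are released to free denoising earlier (differential diffusion).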
+ mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
+ mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)
+ masks = original_mask > mask_thresholds
+ # end diff diff preparation
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 8. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+ timestep = (1000 - timestep) / 1000
+ # Normalized time for time-aware config (0 at start, 1 at end)
+ t_norm = timestep[0].item()
+
+ # Handle cfg truncation
+ current_guidance_scale = self.guidance_scale
+ if (
+ self.do_classifier_free_guidance
+ and self._cfg_truncation is not None
+ and float(self._cfg_truncation) <= 1
+ ):
+ if t_norm > self._cfg_truncation:
+ current_guidance_scale = 0.0
+
+ # Run CFG only if configured AND scale is non-zero
+ apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
+
+ if apply_cfg:
+ latents_typed = latents.to(self.transformer.dtype)
+ latent_model_input = latents_typed.repeat(2, 1, 1, 1)
+ prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
+ timestep_model_input = timestep.repeat(2)
+ else:
+ latent_model_input = latents.to(self.transformer.dtype)
+ prompt_embeds_model_input = prompt_embeds
+ timestep_model_input = timestep
+
+ latent_model_input = latent_model_input.unsqueeze(2)
+ latent_model_input_list = list(latent_model_input.unbind(dim=0))
+
+ model_out_list = self.transformer(
+ latent_model_input_list,
+ timestep_model_input,
+ prompt_embeds_model_input,
+ )[0]
+
+ if apply_cfg:
+ # Perform CFG
+ pos_out = model_out_list[:actual_batch_size]
+ neg_out = model_out_list[actual_batch_size:]
+
+ noise_pred = []
+ for j in range(actual_batch_size):
+ pos = pos_out[j].float()
+ neg = neg_out[j].float()
+
+ pred = pos + current_guidance_scale * (pos - neg)
+
+ # Renormalization
+ if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
+ ori_pos_norm = torch.linalg.vector_norm(pos)
+ new_pos_norm = torch.linalg.vector_norm(pred)
+ max_new_norm = ori_pos_norm * float(self._cfg_normalization)
+ if new_pos_norm > max_new_norm:
+ pred = pred * (max_new_norm / new_pos_norm)
+
+ noise_pred.append(pred)
+
+ noise_pred = torch.stack(noise_pred, dim=0)
+ else:
+ noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)
+
+ noise_pred = noise_pred.squeeze(2)
+ noise_pred = -noise_pred
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
+ assert latents.dtype == torch.float32
+
+ # start diff diff
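+ # Re-noise the original image latents to the next timestep's noise level, then copy them back into
+ # the regions selected by this step's mask so those regions keep following the original image.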
+ image_latent = original_image_latents
+ latents_dtype = latents.dtype
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ image_latent = self.scheduler.scale_noise(
+ original_image_latents, torch.tensor([noise_timestep]), noise
+ )
+
+ mask = masks[i].to(latents_dtype)
+ latents = image_latent * mask + latents * (1 - mask)
+ # end diff diff
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = latents.to(self.vae.dtype)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ZImagePipelineOutput(images=image)
diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md
index 1d1777811387..ad5d61f1f9e2 100644
--- a/examples/dreambooth/README_flux2.md
+++ b/examples/dreambooth/README_flux2.md
@@ -1,14 +1,22 @@
-# DreamBooth training example for FLUX.2 [dev]
+# DreamBooth training example for FLUX.2 [dev] and FLUX 2 [klein]
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
+[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance with only a fraction of learnable parameters.
+
+The `train_dreambooth_lora_flux2.py` and `train_dreambooth_lora_flux2_klein.py` scripts show how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://huggingface.co/black-forest-labs/FLUX.2-dev) and [FLUX 2 [klein]](https://huggingface.co/black-forest-labs/FLUX.2-klein).
-The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
+> [!NOTE]
+> **Model Variants**
+>
+> We support two FLUX model families:
+> - **FLUX.2 [dev]**: The full-size model using Mistral Small 3.1 as the text encoder. Very capable but memory intensive.
+> - **FLUX 2 [klein]**: Available in 4B and 9B parameter variants, using Qwen VL as the text encoder. Much more memory efficient and suitable for consumer hardware.
+
> [!NOTE]
> **Memory consumption**
>
-> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
-> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
+> FLUX.2 [dev] can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
+> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. FLUX 2 [klein] models (4B and 9B) are significantly more memory efficient alternatives. Below we provide some tips and tricks to reduce memory consumption during training.
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
@@ -17,7 +25,7 @@ The `train_dreambooth_lora_flux2.py` script shows how to implement the training
> [!NOTE]
> **Gated model**
>
-> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:
```bash
hf auth login
@@ -88,20 +96,32 @@ snapshot_download(
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
-As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
+As mentioned, Flux2 LoRA training is *very* memory intensive (especially for FLUX.2 [dev]). Here are memory optimizations we can use (some still experimental) for more memory-efficient training:
## Memory Optimizations
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
+
### Remote Text Encoder
-Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
+FLUX.2 [dev] uses Mistral Small 3.1 as the text encoder, which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
This way, the text encoder model is not loaded into memory during training.
+
+> [!IMPORTANT]
+> **Remote text encoder is only supported for FLUX.2 [dev]**. FLUX 2 [klein] models use the Qwen VL text encoder and do not support remote text encoding.
+
> [!NOTE]
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
+
+### FSDP Text Encoder
+FLUX.2 [dev] uses Mistral Small 3.1 as the text encoder, which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to shard the text encoder with FSDP when computing the prompt embeddings.
+This way, the memory cost of the text encoder is distributed across the participating GPUs.
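+
+As a minimal sketch, the flag is simply added to a multi-process launch (the process count below is just an example; the remaining flags mirror the full FLUX.2 [dev] command shown later in this README):
+
+```bash
+accelerate launch --num_processes=2 train_dreambooth_lora_flux2.py \
+  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
+  --instance_data_dir="dog" \
+  --output_dir="trained-flux2" \
+  --instance_prompt="a photo of sks dog" \
+  --fsdp_text_encoder \
+  --cache_latents \
+  --gradient_checkpointing \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --learning_rate=1e-4 \
+  --max_train_steps=500
+```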
+
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
+
### Latent Caching
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
+
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
@@ -111,22 +131,29 @@ enable FP8 training by passing `--do_fp8_training`.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization (see the sketch below for a possible configuration).
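+
+The config file passed via `--bnb_quantization_config_path` presumably mirrors `BitsAndBytesConfig` keyword arguments; the exact schema is defined by the training script, so treat the following as an illustrative sketch only:
+
+```python
+import torch
+from diffusers import BitsAndBytesConfig
+
+# In-code equivalent of a 4-bit NF4 quantization setup; the config file passed to
+# --bnb_quantization_config_path is assumed to mirror these keyword arguments.
+nf4_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+```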
+
### Gradient Checkpointing and Accumulation
* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass.
by passing a value > 1 you can reduce the number of backward/update passes and hence also memory requirements.
* with `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
+
### 8-bit-Adam Optimizer
When training with `AdamW` (doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
+
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
+
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This significantly reduces memory requirements without a noticeable quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
+## Training Examples
+### FLUX.2 [dev] Training
+To perform DreamBooth with LoRA on FLUX.2 [dev], run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
@@ -158,19 +185,104 @@ accelerate launch train_dreambooth_lora_flux2.py \
--push_to_hub
```
-To better track our training experiments, we're using the following flags in the command above:
+### FLUX 2 [klein] Training
-* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before.
-* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+FLUX 2 [klein] models are more memory efficient alternatives available in 4B and 9B parameter variants. They use the Qwen VL text encoder instead of Mistral Small 3.1.
> [!NOTE]
-> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
+> The `--remote_text_encoder` flag is **not supported** for FLUX 2 [klein] models. The Qwen VL text encoder must be loaded locally, but offloading is still supported.
-## LoRA + DreamBooth
+**FLUX 2 [klein] 4B:**
-[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.2-klein-4B"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux2-klein-4b"
-Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
+accelerate launch train_dreambooth_lora_flux2_klein.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --do_fp8_training \
+ --gradient_checkpointing \
+ --cache_latents \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --use_8bit_adam \
+ --gradient_accumulation_steps=4 \
+ --optimizer="adamW" \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+**FLUX 2 [klein] 9B:**
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.2-klein-9B"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux2-klein-9b"
+
+accelerate launch train_dreambooth_lora_flux2_klein.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --do_fp8_training \
+ --gradient_checkpointing \
+ --cache_latents \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --use_8bit_adam \
+ --gradient_accumulation_steps=4 \
+ --optimizer="adamW" \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login` before training if you haven't done it before.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+> [!NOTE]
+> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. Note that this will use more resources and may slow down the training in some cases.
+
+### FSDP on the transformer
+When accelerate is configured with FSDP, the transformer blocks will be wrapped automatically. E.g. set the configuration to:
+
+```shell
+distributed_type: FSDP
+fsdp_config:
+ fsdp_version: 2
+ fsdp_offload_params: false
+ fsdp_sharding_strategy: HYBRID_SHARD
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+ fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
+ fsdp_forward_prefetch: true
+ fsdp_sync_module_states: false
+ fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_use_orig_params: false
+ fsdp_activation_checkpointing: true
+ fsdp_reshard_after_forward: true
+ fsdp_cpu_ram_efficient_loading: false
+```
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of learned parameters based on past gradients, allowing for more efficient convergence.
@@ -183,8 +295,6 @@ to use prodigy, first make sure to install the prodigyopt library: `pip install
> [!TIP]
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
-To perform DreamBooth with LoRA, run:
-
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
@@ -248,13 +358,10 @@ the exact modules for LoRA training. Here are some examples of target modules yo
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
-
## Training Image-to-Image
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image (I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.
-**important**
-
**Important**
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
@@ -311,5 +418,6 @@ we've added aspect ratio bucketing support which allows training on images with
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
-`
-Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
+
+
+Since Flux.2 finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗
\ No newline at end of file
diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md
index 7972434b5e6f..8bddacf975d8 100644
--- a/examples/dreambooth/README_sana.md
+++ b/examples/dreambooth/README_sana.md
@@ -111,6 +111,25 @@ To better track our training experiments, we're using the following flags in the
## Notes
+### LoRA Rank and Alpha
+Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
+- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
+- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
+- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength (see the sketch after the tip below):
+  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
+  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
+  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower-rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
+
+> [!TIP]
+> A common starting point is to set `lora_alpha` equal to `rank`.
+> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
+> to give the LoRA updates more influence without increasing parameter count.
+> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
+> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
+
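+For reference, here is a minimal sketch of how these two values map onto a PEFT `LoraConfig` (the target modules listed are only an illustration and should match the layers you actually want to train):
+
+```python
+from peft import LoraConfig
+
+# The effective LoRA scaling is lora_alpha / r.
+# With r=16 and lora_alpha=32, the LoRA update is scaled by 2.0.
+lora_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    init_lora_weights="gaussian",
+    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+)
+```
+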
+### Additional CLI arguments
+
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
diff --git a/examples/dreambooth/test_dreambooth_lora_flux2_klein.py b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py
new file mode 100644
index 000000000000..0e5506e1a3eb
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py
@@ -0,0 +1,262 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAFlux2Klein(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "dog"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein"
+ script_path = "examples/dreambooth/train_dreambooth_lora_flux2_klein.py"
+ transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
+
+ def test_dreambooth_lora_flux2(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_layers(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lora_layers {self.transformer_layer_type}
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names. In this test, only the params of
+ # `transformer.{self.transformer_layer_type}` should be in the state dict
+ starts_with_transformer = all(
+ key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --max_sequence_length 8
+ --checkpointing_steps=2
+ --text_encoder_out_layers 1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+ def test_dreambooth_lora_with_metadata(self):
+ # Use a `lora_alpha` that is different from `rank`.
+ lora_alpha = 8
+ rank = 4
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --lora_alpha={lora_alpha}
+ --rank={rank}
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --max_sequence_length 8
+ --text_encoder_out_layers 1
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+ self.assertTrue(os.path.isfile(state_dict_file))
+
+ # Check if the metadata was properly serialized.
+ with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+ metadata = f.metadata() or {}
+
+ metadata.pop("format", None)
+ raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+ if raw:
+ raw = json.loads(raw)
+
+ loaded_lora_alpha = raw["transformer.lora_alpha"]
+ self.assertTrue(loaded_lora_alpha == lora_alpha)
+ loaded_lora_rank = raw["transformer.r"]
+ self.assertTrue(loaded_lora_rank == rank)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py
index 81306940af8f..317ed2c2b2e1 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -44,6 +44,7 @@
import warnings
from contextlib import nullcontext
from pathlib import Path
+from typing import Any
import numpy as np
import torch
@@ -75,13 +76,16 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
+ _to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
+ get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
+ wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@
from diffusers.utils.torch_utils import is_compiled_module
+if getattr(torch, "distributed", None) is not None:
+ import torch.distributed as dist
+
if is_wandb_available():
import wandb
@@ -722,6 +729,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+ parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1227,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
- transformer.to(**transformer_to_kwargs)
+
+ is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
+ if not is_fsdp:
+ transformer.to(**transformer_to_kwargs)
+
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1263,17 +1275,42 @@ def unwrap_model(model):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
+ transformer_cls = type(unwrap_model(transformer))
+
+ # 1) Validate and pick the transformer model
+ modules_to_save: dict[str, Any] = {}
+ transformer_model = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), transformer_cls):
+ transformer_model = model
+ modules_to_save["transformer"] = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ if transformer_model is None:
+ raise ValueError("No transformer model found in 'models'")
+
+ # 2) Optionally gather FSDP state dict once
+ state_dict = accelerator.get_state_dict(model) if is_fsdp else None
+
+ # 3) Only main process materializes the LoRA state dict
+ transformer_lora_layers_to_save = None
if accelerator.is_main_process:
- transformer_lora_layers_to_save = None
- modules_to_save = {}
- for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_lora_layers_to_save = get_peft_model_state_dict(model)
- modules_to_save["transformer"] = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ peft_kwargs = {}
+ if is_fsdp:
+ peft_kwargs["state_dict"] = state_dict
+
+ transformer_lora_layers_to_save = get_peft_model_state_dict(
+ unwrap_model(transformer_model) if is_fsdp else transformer_model,
+ **peft_kwargs,
+ )
- # make sure to pop weight so that corresponding model is not saved again
+ if is_fsdp:
+ transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1285,13 +1322,20 @@ def save_model_hook(models, weights, output_dir):
def load_model_hook(models, input_dir):
transformer_ = None
- while len(models) > 0:
- model = models.pop()
+ if not is_fsdp:
+ while len(models) > 0:
+ model = models.pop()
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+ else:
+ transformer_ = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ )
+ transformer_.add_adapter(transformer_lora_config)
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1507,6 +1551,21 @@ def _encode_single(prompt: str):
args.validation_prompt, text_encoding_pipeline
)
+ # Init FSDP for text encoder
+ if args.fsdp_text_encoder:
+ fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+ text_encoder_fsdp = wrap_with_fsdp(
+ model=text_encoding_pipeline.text_encoder,
+ device=accelerator.device,
+ offload=args.offload,
+ limit_all_gathers=True,
+ use_orig_params=True,
+ fsdp_kwargs=fsdp_kwargs,
+ )
+
+ text_encoding_pipeline.text_encoder = text_encoder_fsdp
+ dist.barrier()
+
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1595,8 @@ def _encode_single(prompt: str):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+ elif args.fsdp_text_encoder:
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
progress_bar.update(1)
global_step += 1
- if accelerator.is_main_process:
+ if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# Save the lora layers
accelerator.wait_for_everyone()
+
+ if is_fsdp:
+ transformer = unwrap_model(transformer)
+ state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
- transformer = unwrap_model(transformer)
- if args.bnb_quantization_config_path is None:
- if args.upcast_before_saving:
- transformer.to(torch.float32)
- else:
- transformer = transformer.to(weight_dtype)
- transformer_lora_layers = get_peft_model_state_dict(transformer)
+ if is_fsdp:
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ state_dict = {
+ k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+ else:
+ state_dict = {
+ k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+
+ transformer_lora_layers = get_peft_model_state_dict(
+ transformer,
+ state_dict=state_dict,
+ )
+ transformer_lora_layers = {
+ k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+ for k, v in transformer_lora_layers.items()
+ }
+
+ else:
+ transformer = unwrap_model(transformer)
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
index 0b9b9f993094..16a3863c881d 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -43,6 +43,7 @@
import shutil
from contextlib import nullcontext
from pathlib import Path
+from typing import Any
import numpy as np
import torch
@@ -74,13 +75,16 @@
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
_collate_lora_metadata,
+ _to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
+ get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
+ wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@
from diffusers.utils.torch_utils import is_compiled_module
+if getattr(torch, "distributed", None) is not None:
+ import torch.distributed as dist
+
if is_wandb_available():
import wandb
@@ -120,7 +127,7 @@ def save_model_card(
)
model_description = f"""
-# Flux DreamBooth LoRA - {repo_id}
+# Flux.2 DreamBooth LoRA - {repo_id}
@@ -339,7 +346,7 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
- required=True,
+ required=False,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -691,6 +698,7 @@ def parse_args(input_args=None):
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+ parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -827,15 +835,28 @@ def __init__(
dest_image = self.cond_images[i]
image_width, image_height = dest_image.size
if image_width * image_height > 1024 * 1024:
- dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
+ dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
image_width, image_height = dest_image.size
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
- dest_image = Flux2ImageProcessor.image_processor.preprocess(
+ image_processor = Flux2ImageProcessor()
+ dest_image = image_processor.preprocess(
dest_image, height=image_height, width=image_width, resize_mode="crop"
)
+ # Convert back to PIL
+ dest_image = dest_image.squeeze(0)
+ if dest_image.min() < 0:
+ dest_image = (dest_image + 1) / 2
+ dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
+
+ if dest_image.shape[0] == 1:
+ # Gray scale image
+ dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
+ else:
+ # RGB scale image: (C, H, W) -> (H, W, C)
+ dest_image = TF.to_pil_image(dest_image)
dest_image = exif_transpose(dest_image)
if not dest_image.mode == "RGB":
@@ -1156,7 +1177,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
- transformer.to(**transformer_to_kwargs)
+
+ is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
+ if not is_fsdp:
+ transformer.to(**transformer_to_kwargs)
+
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1200,17 +1225,42 @@ def unwrap_model(model):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
+ transformer_cls = type(unwrap_model(transformer))
+
+ # 1) Validate and pick the transformer model
+ modules_to_save: dict[str, Any] = {}
+ transformer_model = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), transformer_cls):
+ transformer_model = model
+ modules_to_save["transformer"] = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ if transformer_model is None:
+ raise ValueError("No transformer model found in 'models'")
+
+ # 2) Optionally gather FSDP state dict once
+ state_dict = accelerator.get_state_dict(model) if is_fsdp else None
+
+ # 3) Only main process materializes the LoRA state dict
+ transformer_lora_layers_to_save = None
if accelerator.is_main_process:
- transformer_lora_layers_to_save = None
- modules_to_save = {}
- for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_lora_layers_to_save = get_peft_model_state_dict(model)
- modules_to_save["transformer"] = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ peft_kwargs = {}
+ if is_fsdp:
+ peft_kwargs["state_dict"] = state_dict
+
+ transformer_lora_layers_to_save = get_peft_model_state_dict(
+ unwrap_model(transformer_model) if is_fsdp else transformer_model,
+ **peft_kwargs,
+ )
+
+ if is_fsdp:
+ transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
- # make sure to pop weight so that corresponding model is not saved again
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1222,13 +1272,20 @@ def save_model_hook(models, weights, output_dir):
def load_model_hook(models, input_dir):
transformer_ = None
- while len(models) > 0:
- model = models.pop()
+ if not is_fsdp:
+ while len(models) > 0:
+ model = models.pop()
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+ else:
+ transformer_ = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ )
+ transformer_.add_adapter(transformer_lora_config)
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1419,9 +1476,9 @@ def _encode_single(prompt: str):
args.instance_prompt, text_encoding_pipeline
)
- validation_image = load_image(args.validation_image_path).convert("RGB")
- validation_kwargs = {"image": validation_image}
if args.validation_prompt is not None:
+ validation_image = load_image(args.validation_image_path).convert("RGB")
+ validation_kwargs = {"image": validation_image}
if args.remote_text_encoder:
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
else:
@@ -1430,6 +1487,21 @@ def _encode_single(prompt: str):
args.validation_prompt, text_encoding_pipeline
)
+ # Init FSDP for text encoder
+ if args.fsdp_text_encoder:
+ fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+ text_encoder_fsdp = wrap_with_fsdp(
+ model=text_encoding_pipeline.text_encoder,
+ device=accelerator.device,
+ offload=args.offload,
+ limit_all_gathers=True,
+ use_orig_params=True,
+ fsdp_kwargs=fsdp_kwargs,
+ )
+
+ text_encoding_pipeline.text_encoder = text_encoder_fsdp
+ dist.barrier()
+
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1461,6 +1533,8 @@ def _encode_single(prompt: str):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+ elif args.fsdp_text_encoder:
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1621,9 +1695,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
- cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
+ cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
+ cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
device=cond_model_input.device
)
+ cond_model_input_ids = cond_model_input_ids.view(
+ cond_model_input.shape[0], -1, model_input_ids.shape[-1]
+ )
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1650,6 +1728,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
+ orig_input_shape = packed_noisy_model_input.shape
+ orig_input_ids_shape = model_input_ids.shape
+
# concatenate the model inputs with the cond inputs
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
@@ -1668,7 +1749,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
img_ids=model_input_ids, # B, image_seq_len, 4
return_dict=False,
)[0]
- model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
+ model_pred = model_pred[:, : orig_input_shape[1], :]
+ model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
@@ -1700,7 +1782,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
progress_bar.update(1)
global_step += 1
- if accelerator.is_main_process:
+ if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1759,15 +1841,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# Save the lora layers
accelerator.wait_for_everyone()
+
+ if is_fsdp:
+ transformer = unwrap_model(transformer)
+ state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
- transformer = unwrap_model(transformer)
- if args.bnb_quantization_config_path is None:
- if args.upcast_before_saving:
- transformer.to(torch.float32)
- else:
- transformer = transformer.to(weight_dtype)
- transformer_lora_layers = get_peft_model_state_dict(transformer)
+ if is_fsdp:
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ state_dict = {
+ k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+ else:
+ state_dict = {
+ k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+
+ transformer_lora_layers = get_peft_model_state_dict(
+ transformer,
+ state_dict=state_dict,
+ )
+ transformer_lora_layers = {
+ k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+ for k, v in transformer_lora_layers.items()
+ }
+
+ else:
+ transformer = unwrap_model(transformer)
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py
new file mode 100644
index 000000000000..278c25900a3a
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py
@@ -0,0 +1,1942 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
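+
+# The block above is PEP 723 inline script metadata; a PEP 723-aware runner such as `uv run`
+# can use it to resolve these dependencies before executing the script (optional convenience).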
+
+import argparse
+import copy
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+from tqdm.auto import tqdm
+from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
+
+import diffusers
+from diffusers import (
+ AutoencoderKLFlux2,
+ BitsAndBytesConfig,
+ FlowMatchEulerDiscreteScheduler,
+ Flux2KleinPipeline,
+ Flux2Transformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ _collate_lora_metadata,
+ _to_cpu_contiguous,
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ find_nearest_bucket,
+ free_memory,
+ get_fsdp_kwargs_from_accelerator,
+ offload_models,
+ parse_buckets_string,
+ wrap_with_fsdp,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if getattr(torch, "distributed", None) is not None:
+ import torch.distributed as dist
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.37.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+ quant_training=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Flux.2 [Klein] DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).
+
+Quant training? {quant_training}
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux2-klein",
+ "flux2-klein-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(dtype=torch_dtype)
+ pipeline.enable_model_cpu_offload()
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+
+ images = []
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ prompt_embeds=pipeline_args["prompt_embeds"],
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ free_memory()
+
+ return images
+
+
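+# Used by `convert_to_float8_training` below when --do_fp8_training is set: the final projection
+# and any Linear whose in/out features are not multiples of 16 stay in high precision, since the
+# float8 matmul kernels generally expect dimensions divisible by 16.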
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
+ # don't convert the output module
+ if fqn == "proj_out":
+ return False
+ # don't convert linear modules with weight dimensions not divisible by 16
+ if isinstance(mod, torch.nn.Linear):
+ if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+ return False
+ return True
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--bnb_quantization_config_path",
+ type=str,
+ default=None,
+ help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+ )
+ parser.add_argument(
+ "--do_fp8_training",
+ action="store_true",
+ help="if we are doing FP8 training.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=512,
+ help="Maximum sequence length to use with with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--text_encoder_out_layers",
+ type=int,
+ nargs="+",
+ default=[10, 20, 30],
+ help="Text encoder hidden layers to compute the final text embeddings.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--skip_final_inference",
+ default=False,
+ action="store_true",
+ help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
+ )
+ parser.add_argument(
+ "--final_validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=4,
+ help="LoRA alpha to be used for additional scaling.",
+ )
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
+
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--aspect_ratio_buckets",
+ type=str,
+ default=None,
+ help=(
+ "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
+ "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
+ "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the FLUX.1 dev variant is a guidance distilled model",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string, e.g. "to_k,to_q,to_v,to_out.0" to train only the attention layers.'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+ parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+ if args.do_fp8_training and args.bnb_quantization_config_path:
+ raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ buckets=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ self.buckets = buckets
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ for i, image in enumerate(self.instance_images):
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ width, height = image.size
+
+ # Find the closest bucket
+ bucket_idx = find_nearest_bucket(height, width, self.buckets)
+ target_height, target_width = self.buckets[bucket_idx]
+ self.size = (target_height, target_width)
+
+ # based on the bucket assignment, define the transformations
+ image = self.train_transform(
+ image,
+ size=self.size,
+ center_crop=args.center_crop,
+ random_flip=args.random_flip,
+ )
+ self.pixel_values.append((image, bucket_idx))
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
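+ # These transforms are only applied to class (prior-preservation) images in __getitem__;
+ # instance images are already pre-processed via train_transform above.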
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["bucket_idx"] = bucket_idx
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else:  # no custom prompts were provided; fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False):
+ # 1. Resize (deterministic)
+ resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = resize(image)
+
+ # 2. Crop: either center or SAME random crop
+ if center_crop:
+ crop = transforms.CenterCrop(size)
+ image = crop(image)
+ else:
+ # get_params returns (i, j, h, w)
+ i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+ image = TF.crop(image, i, j, h, w)
+
+ # 3. Random horizontal flip with the SAME coin flip
+ if random_flip:
+ do_flip = random.random() < 0.5
+ if do_flip:
+ image = TF.hflip(image)
+
+ # 4. ToTensor + Normalize (deterministic)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+ image = normalize(to_tensor(image))
+
+ return image
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
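+# Batch sampler that only groups images from the same aspect-ratio bucket into a batch, so the
+# tensors can be stacked without further resizing; the batch order is reshuffled every epoch.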
+class BucketBatchSampler(BatchSampler):
+ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
+ if not isinstance(drop_last, bool):
+ raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
+
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ # Group indices by bucket
+ self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
+ for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
+ self.bucket_indices[bucket_idx].append(idx)
+
+ self.sampler_len = 0
+ self.batches = []
+
+ # Pre-generate batches for each bucket
+ for indices_in_bucket in self.bucket_indices:
+ # Shuffle indices within the bucket
+ random.shuffle(indices_in_bucket)
+ # Create batches
+ for i in range(0, len(indices_in_bucket), self.batch_size):
+ batch = indices_in_bucket[i : i + self.batch_size]
+ if len(batch) < self.batch_size and self.drop_last:
+ continue # Skip partial batch if drop_last is True
+ self.batches.append(batch)
+ self.sampler_len += 1 # Count the number of batches
+
+ def __iter__(self):
+ # Shuffle the order of the batches each epoch
+ random.shuffle(self.batches)
+ for batch in self.batches:
+ yield batch
+
+ def __len__(self):
+ return self.sampler_len
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `hf auth login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+ if args.do_fp8_training:
+ from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+
+ pipeline = Flux2KleinPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+ images = pipeline(prompt=example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ free_memory()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer = Qwen2TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ revision=args.revision,
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKLFlux2.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)
+ latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
+ accelerator.device
+ )
+
+ quantization_config = None
+ if args.bnb_quantization_config_path is not None:
+ with open(args.bnb_quantization_config_path, "r") as f:
+ config_kwargs = json.load(f)
+ if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
+ config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+ quantization_config = BitsAndBytesConfig(**config_kwargs)
+
+ transformer = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ quantization_config=quantization_config,
+ torch_dtype=weight_dtype,
+ )
+ if args.bnb_quantization_config_path is not None:
+ transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
+
+ text_encoder = Qwen3ForCausalLM.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder.requires_grad_(False)
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.set_attention_backend("_native_npu")
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype}
+ # flux vae is stable in bf16 so load it in weight_dtype to reduce memory
+ vae.to(**to_kwargs)
+ # we never offload the transformer to CPU, so we can just use the accelerator device
+ transformer_to_kwargs = (
+ {"device": accelerator.device}
+ if args.bnb_quantization_config_path is not None
+ else {"device": accelerator.device, "dtype": weight_dtype}
+ )
+
+ is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
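+ # When an FSDP plugin is configured, `accelerator.prepare` later shards the transformer and
+ # handles device placement, so the manual `.to()` below is skipped in that case.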
+ if not is_fsdp:
+ transformer.to(**transformer_to_kwargs)
+
+ if args.do_fp8_training:
+ convert_to_float8_training(
+ transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
+ )
+
+ text_encoder.to(**to_kwargs)
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = Flux2KleinPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ scheduler=None,
+ revision=args.revision,
+ )
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+
+ # now we will add new LoRA weights the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ transformer_cls = type(unwrap_model(transformer))
+
+ # 1) Validate and pick the transformer model
+ modules_to_save: dict[str, Any] = {}
+ transformer_model = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), transformer_cls):
+ transformer_model = model
+ modules_to_save["transformer"] = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ if transformer_model is None:
+ raise ValueError("No transformer model found in 'models'")
+
+ # 2) Optionally gather FSDP state dict once
+ state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None
+
+ # 3) Only main process materializes the LoRA state dict
+ transformer_lora_layers_to_save = None
+ if accelerator.is_main_process:
+ peft_kwargs = {}
+ if is_fsdp:
+ peft_kwargs["state_dict"] = state_dict
+
+ transformer_lora_layers_to_save = get_peft_model_state_dict(
+ unwrap_model(transformer_model) if is_fsdp else transformer_model,
+ **peft_kwargs,
+ )
+
+ if is_fsdp:
+ transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ Flux2KleinPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ if not is_fsdp:
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+ else:
+ transformer_ = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ )
+ transformer_.add_adapter(transformer_lora_config)
+
+ lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ if args.aspect_ratio_buckets is not None:
+ buckets = parse_buckets_string(args.aspect_ratio_buckets)
+ else:
+ buckets = [(args.resolution, args.resolution)]
+ logger.info(f"Using parsed aspect ratio buckets: {buckets}")
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ buckets=buckets,
+ )
+ batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ def compute_text_embeddings(prompt, text_encoding_pipeline):
+ with torch.no_grad():
+ prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+ prompt=prompt,
+ max_sequence_length=args.max_sequence_length,
+ text_encoder_out_layers=args.text_encoder_out_layers,
+ )
+ return prompt_embeds, text_ids
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not train_dataset.custom_instance_prompts:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoding_pipeline
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ class_prompt_hidden_states, class_text_ids = compute_text_embeddings(
+ args.class_prompt, text_encoding_pipeline
+ )
+ validation_embeddings = {}
+ if args.validation_prompt is not None:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings(
+ args.validation_prompt, text_encoding_pipeline
+ )
+
+ # Init FSDP for text encoder
+ if args.fsdp_text_encoder:
+ fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+ text_encoder_fsdp = wrap_with_fsdp(
+ model=text_encoding_pipeline.text_encoder,
+ device=accelerator.device,
+ offload=args.offload,
+ limit_all_gathers=True,
+ use_orig_params=True,
+ fsdp_kwargs=fsdp_kwargs,
+ )
+
+ text_encoding_pipeline.text_encoder = text_encoder_fsdp
+ dist.barrier()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ if not train_dataset.custom_instance_prompts:
+ prompt_embeds = instance_prompt_hidden_states
+ text_ids = instance_text_ids
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+
+ # if cache_latents is set to True, we encode images to latents and store them.
+ # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
+ # we encode them in advance as well.
+ precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+ if precompute_latents:
+ prompt_embeds_cache = []
+ text_ids_cache = []
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ if args.cache_latents:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ if train_dataset.custom_instance_prompts:
+ if args.fsdp_text_encoder:
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+ prompt_embeds_cache.append(prompt_embeds)
+ text_ids_cache.append(text_ids)
+
+ # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ if args.cache_latents:
+ vae = vae.to("cpu")
+ del vae
+
+ # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ del text_encoder, tokenizer
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux2-klein-lora"
+ args_cp = vars(args).copy()
+ args_cp["text_encoder_out_layers"] = str(args_cp["text_encoder_out_layers"])
+ accelerator.init_trackers(tracker_name, config=args_cp)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
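+ # Look up the continuous flow-matching sigma for each sampled discrete timestep and reshape it
+ # so it broadcasts against the latent tensor when interpolating between clean latents and noise.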
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
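+    # `get_sigmas` above maps the sampled timesteps to their sigmas and unsqueezes the
+    # result to `n_dim` dimensions (e.g. shape (B, 1, 1, 1) for 4-D latents) so that it
+    # broadcasts against the latents in the flow-matching interpolation below.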
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ prompts = batch["prompts"]
+
+ with accelerator.accumulate(models_to_accumulate):
+ if train_dataset.custom_instance_prompts:
+ prompt_embeds = prompt_embeds_cache[step]
+ text_ids = text_ids_cache[step]
+                else:
+                    num_repeat_elements = len(prompts)
+                    # Expand the single cached prompt embedding to the batch size only once,
+                    # so the repeat does not compound across training steps.
+                    if prompt_embeds.shape[0] != num_repeat_elements:
+                        prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
+                        text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].mode()
+ else:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.mode()
+
+ model_input = Flux2KleinPipeline._patchify_latents(model_input)
+ model_input = (model_input - latents_bn_mean) / latents_bn_std
+
+ model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device)
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+                # Add noise according to flow matching:
+                # z_t = (1 - sigma_t) * x + sigma_t * z_1, where x is the clean latent and z_1 is the sampled noise.
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # [B, C, H, W] -> [B, H*W, C]
+ packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)
+
+ # handle guidance
+ if transformer.config.guidance_embeds:
+ guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+ else:
+ guidance = None
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input, # (B, image_seq_len, C)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=model_input_ids, # B, image_seq_len, 4
+ return_dict=False,
+ )[0]
+                model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
+
+ model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids)
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
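+                # For the linear path z_t = (1 - sigma) * x + sigma * z_1, the velocity
+                # d z_t / d sigma equals z_1 - x, i.e. `noise - model_input`, which is the
+                # quantity the transformer is trained to regress.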
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process or is_fsdp:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = Flux2KleinPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_embeddings,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+
+ del pipeline
+ free_memory()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+
+ if is_fsdp:
+ transformer = unwrap_model(transformer)
+ state_dict = accelerator.get_state_dict(transformer)
+ if accelerator.is_main_process:
+ modules_to_save = {}
+ if is_fsdp:
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ state_dict = {
+ k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+ else:
+ state_dict = {
+ k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+
+ transformer_lora_layers = get_peft_model_state_dict(
+ transformer,
+ state_dict=state_dict,
+ )
+ transformer_lora_layers = {
+ k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+ for k, v in transformer_lora_layers.items()
+ }
+
+ else:
+ transformer = unwrap_model(transformer)
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ modules_to_save["transformer"] = transformer
+
+ Flux2KleinPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ images = []
+ run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
+ should_run_final_inference = not args.skip_final_inference and run_validation
+ if should_run_final_inference:
+ pipeline = Flux2KleinPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_embeddings,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+ del pipeline
+ free_memory()
+
+ validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
+ quant_training = None
+ if args.do_fp8_training:
+ quant_training = "FP8 TorchAO"
+ elif args.bnb_quantization_config_path:
+ quant_training = "BitsandBytes"
+ save_model_card(
+ (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=validation_prompt,
+ repo_folder=args.output_dir,
+ quant_training=quant_training,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
new file mode 100644
index 000000000000..28cbaf8f72e7
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
@@ -0,0 +1,1889 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
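+#
+# The block above is PEP 723 inline script metadata. Assuming a PEP 723-aware runner
+# such as `uv` is installed, the dependencies can be resolved automatically for a quick
+# single-process check, e.g.:
+#
+#   uv run train_dreambooth_lora_flux2_klein_img2img.py --help
+#
+# For multi-GPU training, launching through `accelerate launch` is the usual route.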
+
+import argparse
+import copy
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+from tqdm.auto import tqdm
+from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
+
+import diffusers
+from diffusers import (
+ AutoencoderKLFlux2,
+ BitsAndBytesConfig,
+ FlowMatchEulerDiscreteScheduler,
+ Flux2KleinPipeline,
+ Flux2Transformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
+from diffusers.training_utils import (
+ _collate_lora_metadata,
+ _to_cpu_contiguous,
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ find_nearest_bucket,
+ free_memory,
+ get_fsdp_kwargs_from_accelerator,
+ offload_models,
+ parse_buckets_string,
+ wrap_with_fsdp,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+ load_image,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if getattr(torch, "distributed", None) is not None:
+ import torch.distributed as dist
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.37.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+ fp8_training=False,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Flux.2 [Klein] DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).
+
+FP8 training? {fp8_training}
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux2",
+ "flux2-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(dtype=torch_dtype)
+ pipeline.enable_model_cpu_offload()
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+
+ images = []
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ image=pipeline_args["image"],
+ prompt_embeds=pipeline_args["prompt_embeds"],
+ negative_prompt_embeds=pipeline_args["negative_prompt_embeds"],
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ free_memory()
+
+ return images
+
+
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
+ # don't convert the output module
+ if fqn == "proj_out":
+ return False
+ # don't convert linear modules with weight dimensions not divisible by 16
+ if isinstance(mod, torch.nn.Linear):
+ if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+ return False
+ return True
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--bnb_quantization_config_path",
+ type=str,
+ default=None,
+ help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+ )
+ parser.add_argument(
+ "--do_fp8_training",
+ action="store_true",
+ help="if we are doing FP8 training.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--cond_image_column",
+ type=str,
+ default=None,
+ help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=False,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=512,
+        help="Maximum sequence length to use with the text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ help="path to an image that is used during validation as the condition image to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--skip_final_inference",
+ default=False,
+ action="store_true",
+ help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
+ )
+ parser.add_argument(
+ "--final_validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=4,
+ help="LoRA alpha to be used for additional scaling.",
+ )
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
+
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--aspect_ratio_buckets",
+ type=str,
+ default=None,
+ help=(
+ "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
+ "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
+ "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+        help="Guidance scale embedded during training; only used when the transformer was trained with guidance embeddings (i.e. it is guidance distilled).",
+ )
+
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+            'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string, e.g. "to_k,to_q,to_v,to_out.0" will result in LoRA training of the attention layers only'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
+ parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.cond_image_column is None:
+ raise ValueError(
+            "You must provide --cond_image_column for image-to-image training. For text-to-image training, please see the Flux2 text-to-image training example."
+ )
+ else:
+ assert args.image_column is not None
+ assert args.caption_column is not None
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
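+# A hypothetical invocation of this script (all values are placeholders rather than tuned
+# recommendations; every flag shown is defined in `parse_args` above):
+#
+#   accelerate launch train_dreambooth_lora_flux2_klein_img2img.py \
+#     --pretrained_model_name_or_path=<model-id-or-path> \
+#     --dataset_name=<hub-dataset-id> \
+#     --image_column=image --cond_image_column=condition --caption_column=prompt \
+#     --instance_prompt="photo of a TOK object" \
+#     --resolution=1024 --train_batch_size=1 --gradient_accumulation_steps=4 \
+#     --rank=16 --learning_rate=1e-4 --max_train_steps=1000 \
+#     --output_dir=flux2-klein-i2i-lora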
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ buckets=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+
+ self.buckets = buckets
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.cond_image_column is not None and args.cond_image_column not in column_names:
+ raise ValueError(
+ f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+ cond_images = None
+ cond_image_column = args.cond_image_column
+ if cond_image_column is not None:
+ cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
+ assert len(instance_images) == len(cond_images)
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+                raise ValueError("Instance images root doesn't exist.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ self.cond_images = []
+ for i, img in enumerate(instance_images):
+ self.instance_images.extend(itertools.repeat(img, repeats))
+ if args.dataset_name is not None and cond_images is not None:
+ self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
+
+ self.pixel_values = []
+ self.cond_pixel_values = []
+ for i, image in enumerate(self.instance_images):
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ dest_image = None
+ if self.cond_images: # todo: take care of max area for buckets
+ dest_image = self.cond_images[i]
+ image_width, image_height = dest_image.size
+ if image_width * image_height > 1024 * 1024:
+ dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
+ image_width, image_height = dest_image.size
+
+ multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
+ image_width = (image_width // multiple_of) * multiple_of
+ image_height = (image_height // multiple_of) * multiple_of
+ image_processor = Flux2ImageProcessor()
+ dest_image = image_processor.preprocess(
+ dest_image, height=image_height, width=image_width, resize_mode="crop"
+ )
+ # Convert back to PIL
+ dest_image = dest_image.squeeze(0)
+ if dest_image.min() < 0:
+ dest_image = (dest_image + 1) / 2
+ dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
+
+ if dest_image.shape[0] == 1:
+ # Gray scale image
+ dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
+ else:
+ # RGB scale image: (C, H, W) -> (H, W, C)
+ dest_image = TF.to_pil_image(dest_image)
+
+ dest_image = exif_transpose(dest_image)
+ if not dest_image.mode == "RGB":
+ dest_image = dest_image.convert("RGB")
+
+ width, height = image.size
+
+ # Find the closest bucket
+ bucket_idx = find_nearest_bucket(height, width, self.buckets)
+ target_height, target_width = self.buckets[bucket_idx]
+ self.size = (target_height, target_width)
+
+ # based on the bucket assignment, define the transformations
+ image, dest_image = self.paired_transform(
+ image,
+ dest_image=dest_image,
+ size=self.size,
+ center_crop=args.center_crop,
+ random_flip=args.random_flip,
+ )
+ self.pixel_values.append((image, bucket_idx))
+ if dest_image is not None:
+ self.cond_pixel_values.append((dest_image, bucket_idx))
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["bucket_idx"] = bucket_idx
+ if self.cond_pixel_values:
+ dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
+ example["cond_images"] = dest_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+        else:  # no custom instance prompts were provided; fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ return example
+
+ def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
+ # 1. Resize (deterministic)
+ resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = resize(image)
+ if dest_image is not None:
+ dest_image = resize(dest_image)
+
+ # 2. Crop: either center or SAME random crop
+ if center_crop:
+ crop = transforms.CenterCrop(size)
+ image = crop(image)
+ if dest_image is not None:
+ dest_image = crop(dest_image)
+ else:
+ # get_params returns (i, j, h, w)
+ i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+ image = TF.crop(image, i, j, h, w)
+ if dest_image is not None:
+ dest_image = TF.crop(dest_image, i, j, h, w)
+
+ # 3. Random horizontal flip with the SAME coin flip
+ if random_flip:
+ do_flip = random.random() < 0.5
+ if do_flip:
+ image = TF.hflip(image)
+ if dest_image is not None:
+ dest_image = TF.hflip(dest_image)
+
+ # 4. ToTensor + Normalize (deterministic)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+ image = normalize(to_tensor(image))
+ if dest_image is not None:
+ dest_image = normalize(to_tensor(dest_image))
+
+ return (image, dest_image) if dest_image is not None else (image, None)
+
+
+def collate_fn(examples):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ if any("cond_images" in example for example in examples):
+ cond_pixel_values = [example["cond_images"] for example in examples]
+ cond_pixel_values = torch.stack(cond_pixel_values)
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
+ batch.update({"cond_pixel_values": cond_pixel_values})
+ return batch
+
+
+class BucketBatchSampler(BatchSampler):
+ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
+ if not isinstance(drop_last, bool):
+ raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
+
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ # Group indices by bucket
+ self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
+ for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
+ self.bucket_indices[bucket_idx].append(idx)
+
+ self.sampler_len = 0
+ self.batches = []
+
+ # Pre-generate batches for each bucket
+ for indices_in_bucket in self.bucket_indices:
+ # Shuffle indices within the bucket
+ random.shuffle(indices_in_bucket)
+ # Create batches
+ for i in range(0, len(indices_in_bucket), self.batch_size):
+ batch = indices_in_bucket[i : i + self.batch_size]
+ if len(batch) < self.batch_size and self.drop_last:
+ continue # Skip partial batch if drop_last is True
+ self.batches.append(batch)
+ self.sampler_len += 1 # Count the number of batches
+
+ def __iter__(self):
+ # Shuffle the order of the batches each epoch
+ random.shuffle(self.batches)
+ for batch in self.batches:
+ yield batch
+
+ def __len__(self):
+ return self.sampler_len
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `hf auth login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+ if args.do_fp8_training:
+ from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer = Qwen2TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ revision=args.revision,
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKLFlux2.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)
+ latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
+ accelerator.device
+ )
+
+ quantization_config = None
+ if args.bnb_quantization_config_path is not None:
+ with open(args.bnb_quantization_config_path, "r") as f:
+ config_kwargs = json.load(f)
+ if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
+ config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+ quantization_config = BitsAndBytesConfig(**config_kwargs)
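+        # As an illustration only, the JSON file passed via `--bnb_quantization_config_path`
+        # could contain bitsandbytes kwargs such as (hypothetical values):
+        #   {"load_in_4bit": true, "bnb_4bit_quant_type": "nf4", "bnb_4bit_use_double_quant": true}
+        # Any keyword accepted by `BitsAndBytesConfig` can be used here.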
+
+ transformer = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ quantization_config=quantization_config,
+ torch_dtype=weight_dtype,
+ )
+ if args.bnb_quantization_config_path is not None:
+ transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
+
+ text_encoder = Qwen3ForCausalLM.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder.requires_grad_(False)
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.set_attention_backend("_native_npu")
+ else:
+            raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype}
+ # flux vae is stable in bf16 so load it in weight_dtype to reduce memory
+ vae.to(**to_kwargs)
+ # we never offload the transformer to CPU, so we can just use the accelerator device
+ transformer_to_kwargs = (
+ {"device": accelerator.device}
+ if args.bnb_quantization_config_path is not None
+ else {"device": accelerator.device, "dtype": weight_dtype}
+ )
+
+ is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
+ if not is_fsdp:
+ transformer.to(**transformer_to_kwargs)
+
+ if args.do_fp8_training:
+ convert_to_float8_training(
+ transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
+ )
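+        # `module_filter_fn` (defined above) keeps `proj_out` and any `nn.Linear` whose
+        # in/out features are not divisible by 16 in their original precision; only the
+        # remaining linear layers are converted to float8 training.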
+
+ text_encoder.to(**to_kwargs)
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = Flux2KleinPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ scheduler=None,
+ revision=args.revision,
+ )
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+
+    # now we will add new LoRA weights to the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ transformer_cls = type(unwrap_model(transformer))
+
+ # 1) Validate and pick the transformer model
+ modules_to_save: dict[str, Any] = {}
+ transformer_model = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), transformer_cls):
+ transformer_model = model
+ modules_to_save["transformer"] = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ if transformer_model is None:
+ raise ValueError("No transformer model found in 'models'")
+
+ # 2) Optionally gather FSDP state dict once
+ state_dict = accelerator.get_state_dict(model) if is_fsdp else None
+
+ # 3) Only main process materializes the LoRA state dict
+ transformer_lora_layers_to_save = None
+ if accelerator.is_main_process:
+ peft_kwargs = {}
+ if is_fsdp:
+ peft_kwargs["state_dict"] = state_dict
+
+ transformer_lora_layers_to_save = get_peft_model_state_dict(
+ unwrap_model(transformer_model) if is_fsdp else transformer_model,
+ **peft_kwargs,
+ )
+
+ if is_fsdp:
+ transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ Flux2KleinPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ if not is_fsdp:
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+ else:
+ transformer_ = Flux2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ )
+ transformer_.add_adapter(transformer_lora_config)
+
+ lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+            "Defaulting to adamW."
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ if args.aspect_ratio_buckets is not None:
+ buckets = parse_buckets_string(args.aspect_ratio_buckets)
+ else:
+ buckets = [(args.resolution, args.resolution)]
+    logger.info(f"Using aspect ratio buckets: {buckets}")
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ buckets=buckets,
+ )
+ batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=lambda examples: collate_fn(examples),
+ num_workers=args.dataloader_num_workers,
+ )
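+    # Because a `batch_sampler` is supplied, batching, shuffling and `drop_last` are handled
+    # entirely by `BucketBatchSampler`; the DataLoader's own `batch_size`, `shuffle`, `sampler`
+    # and `drop_last` arguments must be left at their defaults (PyTorch raises otherwise).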
+
+ def compute_text_embeddings(prompt, text_encoding_pipeline):
+ with torch.no_grad():
+ prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+ prompt=prompt, max_sequence_length=args.max_sequence_length
+ )
+ return prompt_embeds, text_ids
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not train_dataset.custom_instance_prompts:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoding_pipeline
+ )
+
+ if args.validation_prompt is not None:
+ validation_image = load_image(args.validation_image).convert("RGB")
+ validation_kwargs = {"image": validation_image}
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ validation_kwargs["prompt_embeds"], _text_ids = compute_text_embeddings(
+ args.validation_prompt, text_encoding_pipeline
+ )
+ validation_kwargs["negative_prompt_embeds"], _text_ids = compute_text_embeddings(
+ "", text_encoding_pipeline
+ )
+
+ # Init FSDP for text encoder
+ if args.fsdp_text_encoder:
+ fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+ text_encoder_fsdp = wrap_with_fsdp(
+ model=text_encoding_pipeline.text_encoder,
+ device=accelerator.device,
+ offload=args.offload,
+ limit_all_gathers=True,
+ use_orig_params=True,
+ fsdp_kwargs=fsdp_kwargs,
+ )
+
+ text_encoding_pipeline.text_encoder = text_encoder_fsdp
+ dist.barrier()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ if not train_dataset.custom_instance_prompts:
+ prompt_embeds = instance_prompt_hidden_states
+ text_ids = instance_text_ids
+
+ # if cache_latents is set to True, we encode images to latents and store them.
+ # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
+ # we encode them in advance as well.
+ precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+ if precompute_latents:
+ prompt_embeds_cache = []
+ text_ids_cache = []
+ latents_cache = []
+ cond_latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ if args.cache_latents:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
+ if train_dataset.custom_instance_prompts:
+ if args.fsdp_text_encoder:
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+ else:
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+ prompt_embeds_cache.append(prompt_embeds)
+ text_ids_cache.append(text_ids)
+
+    # move back to cpu before deleting to ensure memory is freed, see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ if args.cache_latents:
+ vae = vae.to("cpu")
+ del vae
+
+    # move back to cpu before deleting to ensure memory is freed, see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ del text_encoder, tokenizer
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux2-image2img-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ prompts = batch["prompts"]
+
+ with accelerator.accumulate(models_to_accumulate):
+ if train_dataset.custom_instance_prompts:
+ prompt_embeds = prompt_embeds_cache[step]
+ text_ids = text_ids_cache[step]
+ else:
+ num_repeat_elements = len(prompts)
+ prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
+ text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].mode()
+ cond_model_input = cond_latents_cache[step].mode()
+ else:
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
+
+ model_input = vae.encode(pixel_values).latent_dist.mode()
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
+
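+                # Patchify the latents and normalize them with the precomputed latent statistics (latents_bn_mean / latents_bn_std).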
+ model_input = Flux2KleinPipeline._patchify_latents(model_input)
+ model_input = (model_input - latents_bn_mean) / latents_bn_std
+
+ cond_model_input = Flux2KleinPipeline._patchify_latents(cond_model_input)
+ cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
+
+ model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device)
+ cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
+ cond_model_input_ids = Flux2KleinPipeline._prepare_image_ids(cond_model_input_list).to(
+ device=cond_model_input.device
+ )
+ cond_model_input_ids = cond_model_input_ids.view(
+ cond_model_input.shape[0], -1, model_input_ids.shape[-1]
+ )
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # [B, C, H, W] -> [B, H*W, C]
+ # concatenate the model inputs with the cond inputs
+ packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)
+ packed_cond_model_input = Flux2KleinPipeline._pack_latents(cond_model_input)
+ orig_input_shape = packed_noisy_model_input.shape
+ orig_input_ids_shape = model_input_ids.shape
+
+ # concatenate the model inputs with the cond inputs
+ packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
+ model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
+
+ # handle guidance
+ if transformer.config.guidance_embeds:
+ guidance = torch.full([1], args.guidance_scale, device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+ else:
+ guidance = None
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input, # (B, image_seq_len, C)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=model_input_ids, # B, image_seq_len, 4
+ return_dict=False,
+ )[0]
+                # prune the conditioning tokens so only predictions for the target latents remain
+ model_pred = model_pred[:, : orig_input_shape[1], :]
+ model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
+
+ model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids)
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process or is_fsdp:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = Flux2KleinPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=None,
+ tokenizer=None,
+ transformer=unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_kwargs,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+
+ del pipeline
+ free_memory()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+
+ if is_fsdp:
+ transformer = unwrap_model(transformer)
+ state_dict = accelerator.get_state_dict(transformer)
+ if accelerator.is_main_process:
+ modules_to_save = {}
+ if is_fsdp:
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ state_dict = {
+ k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+ else:
+ state_dict = {
+ k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
+ }
+
+ transformer_lora_layers = get_peft_model_state_dict(
+ transformer,
+ state_dict=state_dict,
+ )
+ transformer_lora_layers = {
+ k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+ for k, v in transformer_lora_layers.items()
+ }
+
+ else:
+ transformer = unwrap_model(transformer)
+ if args.bnb_quantization_config_path is None:
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ modules_to_save["transformer"] = transformer
+
+ Flux2KleinPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ images = []
+ run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
+ should_run_final_inference = not args.skip_final_inference and run_validation
+ if should_run_final_inference:
+ pipeline = Flux2KleinPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=validation_kwargs,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+ del pipeline
+ free_memory()
+
+ validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
+ save_model_card(
+ (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=validation_prompt,
+ repo_folder=args.output_dir,
+ fp8_training=args.do_fp8_training,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
index 53b01bf0cfc8..ea9b137b0acd 100644
--- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py
+++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
height=model_input.shape[3],
width=model_input.shape[4],
)
- print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
model_pred = transformer(
hidden_states=packed_noisy_model_input,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
timestep=timesteps / 1000,
img_shapes=img_shapes,
- txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)[0]
model_pred = QwenImagePipeline._unpack_latents(
diff --git a/examples/research_projects/lpl/README.md b/examples/research_projects/lpl/README.md
new file mode 100644
index 000000000000..a69fead50893
--- /dev/null
+++ b/examples/research_projects/lpl/README.md
@@ -0,0 +1,157 @@
+# Latent Perceptual Loss (LPL) for Stable Diffusion XL
+
+This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada.
+
+## Overview
+
+LPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to:
+
+- Loss of fine details in generated images
+- Inconsistent image quality
+- Structural artifacts
+- Reduced sharpness and realism
+
+LPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to:
+
+- Improved image quality and consistency (6-20% FID improvement)
+- Better preservation of fine details
+- More stable training, especially at high noise levels
+- Better handling of structural information
+- Sharper and more realistic textures
+
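+To make the mechanism concrete, here is a minimal PyTorch sketch of the idea. A toy decoder stands in for the VAE decoder (the actual implementation in `lpl_loss.py` walks the mid and up blocks of the diffusers VAE decoder); the predicted latents are compared to the (detached) target latents layer by layer:
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ToyDecoder(nn.Module):
+    """Stand-in for a VAE decoder; each stage yields one intermediate feature map."""
+
+    def __init__(self, latent_channels=4, width=32):
+        super().__init__()
+        self.stages = nn.ModuleList(
+            [
+                nn.Sequential(nn.Conv2d(latent_channels, width, 3, padding=1), nn.SiLU()),
+                nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(width, width, 3, padding=1), nn.SiLU()),
+                nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(width, width, 3, padding=1), nn.SiLU()),
+            ]
+        )
+
+    def features(self, z):
+        feats = []
+        for stage in self.stages:
+            z = stage(z)
+            feats.append(z)
+        return feats
+
+
+def toy_lpl(decoder, pred_latents, target_latents):
+    # Compare decoder features of the predicted latents against those of the
+    # target latents (no gradients flow through the target branch).
+    pred_feats = decoder.features(pred_latents)
+    with torch.no_grad():
+        target_feats = decoder.features(target_latents)
+    losses = [F.mse_loss(p, t) for p, t in zip(pred_feats, target_feats)]
+    return sum(losses) / len(losses)
+
+
+decoder = ToyDecoder()
+pred = torch.randn(2, 4, 16, 16, requires_grad=True)
+target = torch.randn(2, 4, 16, 16)
+print(toy_lpl(decoder, pred, target))
+```
+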
+## Implementation Details
+
+The LPL implementation follows the paper's methodology and includes several key features:
+
+1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:
+ - Middle block features
+ - Up block features (configurable number of blocks)
+ - Proper gradient checkpointing for memory efficiency
+ - Features are extracted only for timesteps below the threshold (high SNR)
+
+2. **Feature Normalization**: Multiple normalization options as validated in the paper:
+ - `default`: Normalize each feature map independently
+   - `shared`: Cross-normalize features using target statistics (recommended; see the sketch after this list)
+ - `batch`: Batch-wise normalization
+
+3. **Outlier Handling**: Optional removal of outliers in feature maps using:
+ - Quantile-based filtering (2% quantiles)
+ - Morphological operations (opening/closing)
+ - Adaptive thresholding based on standard deviation
+
+4. **Loss Types**:
+ - MSE loss (default)
+ - L1 loss
+ - Optional power law weighting (2^(-i) for layer i)
+
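+The `shared` option above boils down to the `cross_normalize` helper in `lpl_loss.py`: both feature maps are divided by the target's per-position channel norm, so the loss is measured on the target's feature scale:
+
+```python
+import torch
+
+
+def cross_normalize(input, target, eps=1e-10):
+    # Normalize both tensors by the target's channel-wise L2 norm at each spatial position.
+    norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
+    return input / (norm_factor + eps), target / (norm_factor + eps)
+
+
+x, y = torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4)
+xn, yn = cross_normalize(x, y)
+# After normalization, the target has (approximately) unit channel norm everywhere.
+print(torch.allclose(yn.pow(2).sum(dim=1).sqrt(), torch.ones(1, 4, 4), atol=1e-4))
+```
+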
+## Usage
+
+To use LPL in your training, add the following arguments to your training command (each flag is described in the Key Parameters section below):
+
+```bash
+python examples/research_projects/lpl/train_sdxl_lpl.py \
+    --use_lpl \
+    --lpl_weight 1.0 \
+    --lpl_t_threshold 200 \
+    --lpl_loss_type mse \
+    --lpl_norm_type shared \
+    --lpl_pow_law \
+    --lpl_num_blocks 4 \
+    --lpl_remove_outliers \
+    --lpl_scale \
+    --lpl_start 0 \
+    # ... other training arguments ...
+```
+
+### Key Parameters
+
+- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.
+- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) focus on more important timesteps.
+- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases.
+- `lpl_norm_type`: Feature normalization strategy. "shared" is recommended as it showed the best results in the paper.
+- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance.
+- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.
+- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.
+- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps.
+- `lpl_start`: Training step to start applying LPL. Can be used to warm up training.
+
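+Schematically, these flags combine as sketched below. This is a simplified illustration, not the exact code in `train_sdxl_lpl.py`; `noise_level_weight` is a stand-in for whatever per-timestep weighting `--lpl_scale` applies:
+
+```python
+import torch
+
+
+def combine_losses(
+    diffusion_loss,           # scalar base diffusion loss
+    lpl_per_sample,           # (B,) LPL values, one per sample
+    timesteps,                # (B,) sampled timesteps
+    global_step,
+    lpl_weight=1.0,
+    lpl_t_threshold=200,
+    lpl_start=0,
+    noise_level_weight=None,  # optional (B,) weights, used when --lpl_scale is set
+):
+    if global_step < lpl_start:
+        return diffusion_loss
+    # Only low-noise (high-SNR) timesteps below the threshold contribute.
+    mask = (timesteps < lpl_t_threshold).float()
+    lpl_term = lpl_per_sample * mask
+    if noise_level_weight is not None:
+        lpl_term = lpl_term * noise_level_weight
+    return diffusion_loss + lpl_weight * lpl_term.mean()
+
+
+print(
+    combine_losses(
+        torch.tensor(0.12),
+        lpl_per_sample=torch.rand(4),
+        timesteps=torch.tensor([50, 150, 400, 900]),
+        global_step=100,
+    )
+)
+```
+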
+## Recommendations
+
+1. **Starting Point** (based on paper results):
+ ```bash
+ --use_lpl \
+ --lpl_weight 1.0 \
+ --lpl_t_threshold 200 \
+ --lpl_loss_type mse \
+ --lpl_norm_type shared \
+ --lpl_pow_law \
+ --lpl_num_blocks 4 \
+ --lpl_remove_outliers \
+ --lpl_scale
+ ```
+
+2. **Memory Efficiency**:
+ - Use `--gradient_checkpointing` for memory efficiency (enabled by default)
+ - Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)
+ - Consider using `--lpl_scale` to focus on more important timesteps
+ - Features are extracted only for timesteps below threshold to save memory
+
+3. **Quality vs Speed**:
+ - Higher `lpl_weight` (1.0-2.0) for better quality
+ - Lower `lpl_t_threshold` (100-200) for faster training
+ - Use `lpl_remove_outliers` for more stable training
+   - `lpl_norm_type shared` provides the best quality/speed trade-off
+
+## Technical Details
+
+### Feature Extraction
+
+The LPL implementation extracts features from the VAE decoder in the following order:
+1. Middle block output
+2. Up block outputs (configurable number of blocks)
+
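+One way to capture features in that order is with forward hooks on the decoder's `mid_block` and `up_blocks`, similar to the hook helper in `train_sdxl_lpl.py`. The checkpoint name below is only an example; any `AutoencoderKL` works:
+
+```python
+import torch
+from diffusers import AutoencoderKL
+
+features = {}
+
+
+def save_output(name):
+    def hook(module, inputs, output):
+        # Some blocks return tuples; keep the hidden states.
+        features[name] = output[0] if isinstance(output, tuple) else output
+
+    return hook
+
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
+handles = [vae.decoder.mid_block.register_forward_hook(save_output("mid"))]
+for i, block in enumerate(vae.decoder.up_blocks[:4]):
+    handles.append(block.register_forward_hook(save_output(f"up_{i}")))
+
+with torch.no_grad():
+    vae.decode(torch.randn(1, 4, 32, 32))
+
+print({name: tuple(feat.shape) for name, feat in features.items()})
+for handle in handles:
+    handle.remove()
+```
+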
+Each feature map is processed with:
+1. Optional outlier removal (2% quantiles, morphological operations)
+2. Feature normalization (shared statistics recommended)
+3. Loss calculation (MSE or L1)
+4. Optional power law weighting (2^(-i) for layer i)
+
+### Loss Calculation
+
+For each feature map:
+1. Features are normalized according to the chosen strategy
+2. Loss is calculated between normalized features
+3. Outliers are masked out (if enabled)
+4. Loss is weighted by layer depth (if power law enabled)
+5. Final loss is averaged across all layers
+
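+Condensed, the per-layer reduction above looks like the sketch below; the full version in `LatentPerceptualLoss.get_loss` also computes the outlier masks and normalization per layer:
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def aggregate_lpl(pred_feats, target_feats, masks=None, pow_law=True):
+    # pred_feats / target_feats: lists of (B, C, H, W) tensors, already normalized.
+    losses = []
+    for i, (x, y) in enumerate(zip(pred_feats, target_feats)):
+        mask = masks[i] if masks is not None else torch.ones_like(y, dtype=torch.bool)
+        term = F.mse_loss(x, y, reduction="none") * mask
+        layer_weight = 2 ** (-min(i, 3)) if pow_law else 1.0  # deeper layers count less
+        term = term.sum(dim=(2, 3)) * layer_weight / mask.sum(dim=(2, 3))  # masked spatial mean
+        losses.append(term.mean(dim=1))  # average over channels -> (B,)
+    return sum(losses).mean() / len(pred_feats)  # average over layers and batch
+
+
+pred = [torch.randn(2, 8, 16, 16) for _ in range(3)]
+target = [p + 0.1 * torch.randn_like(p) for p in pred]
+print(aggregate_lpl(pred, target))
+```
+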
+### Memory Considerations
+
+- Gradient checkpointing is used by default
+- Features are extracted only for timesteps below the threshold
+- Outlier removal is done in-place to save memory
+- Feature normalization is done efficiently using vectorized operations
+- Memory usage scales linearly with number of blocks used
+
+## Results
+
+Based on the paper's findings, LPL provides:
+- 6-20% improvement in FID scores
+- Better preservation of fine details
+- More realistic textures and structures
+- Improved consistency across different resolutions
+- Better performance on both small and large datasets
+
+## Citation
+
+If you use this implementation in your research, please cite:
+
+```bibtex
+@inproceedings{berrada2025boosting,
+ title={Boosting Latent Diffusion with Perceptual Objectives},
+ author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},
+ booktitle={The Thirteenth International Conference on Learning Representations},
+ year={2025},
+ url={https://openreview.net/forum?id=y4DtzADzd1}
+}
+```
diff --git a/examples/research_projects/lpl/lpl_loss.py b/examples/research_projects/lpl/lpl_loss.py
new file mode 100644
index 000000000000..de14a4d8d5aa
--- /dev/null
+++ b/examples/research_projects/lpl/lpl_loss.py
@@ -0,0 +1,215 @@
+# Copyright 2025 Berrada et al.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def normalize_tensor(in_feat, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
+ return in_feat / (norm_factor + eps)
+
+
+def cross_normalize(input, target, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
+ return input / (norm_factor + eps), target / (norm_factor + eps)
+
+
+def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
+ opening = int(np.ceil(opening / down_f))
+ closing = int(np.ceil(closing / down_f))
+ if opening == 2:
+ opening = 3
+ if closing == 2:
+ closing = 1
+
+ # replace quantile with kth value here.
+ feat_flat = feat.flatten(-2, -1)
+ k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))
+ q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
+ q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]
+
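+    # NOTE: this overrides the default value of m with an adaptive threshold (2x the per-feature std).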
+ m = 2 * feat_flat.std(-1)[..., None, None].detach()
+ mask = (q1 - m < feat) * (feat < q2 + m)
+
+ # dilate the mask.
+ mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float()) # closing
+ mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool() # opening
+ feat = feat * mask
+ return mask, feat
+
+
+class LatentPerceptualLoss(nn.Module):
+ def __init__(
+ self,
+ vae,
+ loss_type="mse",
+ grad_ckpt=True,
+ pow_law=False,
+ norm_type="default",
+ num_mid_blocks=4,
+ feature_type="feature",
+ remove_outliers=True,
+ ):
+ super().__init__()
+ self.vae = vae
+ self.decoder = self.vae.decoder
+ # Store scaling factors as tensors on the correct device
+ device = next(self.vae.parameters()).device
+
+ # Get scaling factors with proper defaults and handle None values
+ scale_factor = getattr(self.vae.config, "scaling_factor", None)
+ shift_factor = getattr(self.vae.config, "shift_factor", None)
+
+ # Convert to tensors with proper defaults
+ self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
+ self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)
+
+ self.gradient_checkpointing = grad_ckpt
+ self.pow_law = pow_law
+ self.norm_type = norm_type.lower()
+ self.outlier_mask = remove_outliers
+ self.last_feature_stats = [] # Store feature statistics for logging
+
+ assert feature_type in ["feature", "image"]
+ self.feature_type = feature_type
+
+ assert self.norm_type in ["default", "shared", "batch"]
+ assert num_mid_blocks >= 0 and num_mid_blocks <= 4
+ self.n_blocks = num_mid_blocks
+
+ assert loss_type in ["mse", "l1"]
+ if loss_type == "mse":
+ self.loss_fn = nn.MSELoss(reduction="none")
+ elif loss_type == "l1":
+ self.loss_fn = nn.L1Loss(reduction="none")
+
+ def get_features(self, z, latent_embeds=None, disable_grads=False):
+ with torch.set_grad_enabled(not disable_grads):
+ if self.gradient_checkpointing and not disable_grads:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ features = []
+ upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
+ sample = z
+ sample = self.decoder.conv_in(sample)
+
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.decoder.mid_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ sample = sample.to(upscale_dtype)
+ features.append(sample)
+
+ # up
+ for up_block in self.decoder.up_blocks[: self.n_blocks]:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ features.append(sample)
+ return features
+ else:
+ features = []
+ upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
+ sample = z
+ sample = self.decoder.conv_in(sample)
+
+ # middle
+ sample = self.decoder.mid_block(sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+ features.append(sample)
+
+ # up
+ for up_block in self.decoder.up_blocks[: self.n_blocks]:
+ sample = up_block(sample, latent_embeds)
+ features.append(sample)
+ return features
+
+ def get_loss(self, input, target, get_hist=False):
+ if self.feature_type == "feature":
+ inp_f = self.get_features(self.shift + input / self.scale)
+ tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
+ losses = []
+ self.last_feature_stats = [] # Reset feature stats
+
+ for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
+ my = torch.ones_like(y).bool()
+ outlier_ratio = 0.0
+
+ if self.outlier_mask:
+ with torch.no_grad():
+ if i == 2:
+ my, y = remove_outliers(y, down_f=2)
+ outlier_ratio = 1.0 - my.float().mean().item()
+ elif i in [3, 4, 5]:
+ my, y = remove_outliers(y, down_f=1)
+ outlier_ratio = 1.0 - my.float().mean().item()
+
+ # Store feature statistics before normalization
+ with torch.no_grad():
+ stats = {
+ "mean": y.mean().item(),
+ "std": y.std().item(),
+ "outlier_ratio": outlier_ratio,
+ }
+ self.last_feature_stats.append(stats)
+
+ # normalize feature tensors
+ if self.norm_type == "default":
+ x = normalize_tensor(x)
+ y = normalize_tensor(y)
+ elif self.norm_type == "shared":
+ x, y = cross_normalize(x, y, eps=1e-6)
+
+ term_loss = self.loss_fn(x, y) * my
+ # reduce loss term
+ loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
+ term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
+ losses.append(term_loss.mean((1,)))
+
+ if get_hist:
+ return losses
+ else:
+ loss = sum(losses)
+ return loss / len(inp_f)
+ elif self.feature_type == "image":
+ inp_f = self.vae.decode(input / self.scale).sample
+ tar_f = self.vae.decode(target / self.scale).sample
+ return F.mse_loss(inp_f, tar_f)
+
+ def get_first_conv(self, z):
+ sample = self.decoder.conv_in(z)
+ return sample
+
+ def get_first_block(self, z):
+ sample = self.decoder.conv_in(z)
+ sample = self.decoder.mid_block(sample)
+ for resnet in self.decoder.up_blocks[0].resnets:
+ sample = resnet(sample, None)
+ return sample
+
+ def get_first_layer(self, input, target, target_layer="conv"):
+ if target_layer == "conv":
+ feat_in = self.get_first_conv(input)
+ with torch.no_grad():
+ feat_tar = self.get_first_conv(target)
+ else:
+ feat_in = self.get_first_block(input)
+ with torch.no_grad():
+ feat_tar = self.get_first_block(target)
+
+ feat_in, feat_tar = cross_normalize(feat_in, feat_tar)
+
+ return F.mse_loss(feat_in, feat_tar, reduction="mean")
diff --git a/examples/research_projects/lpl/train_sdxl_lpl.py b/examples/research_projects/lpl/train_sdxl_lpl.py
new file mode 100644
index 000000000000..4c472c8871c0
--- /dev/null
+++ b/examples/research_projects/lpl/train_sdxl_lpl.py
@@ -0,0 +1,1622 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LPL training script for Stable Diffusion XL for text2image."""
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import re
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from datasets import concatenate_datasets, load_dataset
+from huggingface_hub import create_repo, upload_folder
+from lpl_loss import LatentPerceptualLoss
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+if is_torch_npu_available():
+ import torch_npu
+
+ torch.npu.config.allow_internal_format = False
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+# Global dictionary to store intermediate features from hooks
+hook_features: Dict[str, torch.Tensor] = {}
+
+
+def get_intermediate_features_hook(name: str):
+ """Creates a hook function that saves the output of a layer."""
+
+ def hook(model, input, output):
+ # Some layers might return tuples (e.g., attention blocks)
+ # We are usually interested in the first element (hidden states)
+ if isinstance(output, tuple):
+ hook_features[name] = output[0]
+ else:
+ hook_features[name] = output
+
+ return hook
+
+
+def clear_hook_features():
+ """Clears the global feature dictionary."""
+ global hook_features
+ hook_features = {}
+
+
+def normalize_features(
+ feat1: torch.Tensor, feat2: torch.Tensor, eps: float = 1e-6
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Normalizes feat1 and feat2 using the statistics of feat2 (predicted features).
+ Normalization is done per-channel.
+ """
+ # Calculate stats over spatial dimensions (H, W)
+ dims = tuple(range(2, feat2.ndim)) # Dims to reduce over (usually 2, 3 for H, W)
+ mean = torch.mean(feat2, dim=dims, keepdim=True)
+ std = torch.std(feat2, dim=dims, keepdim=True) + eps
+
+ feat1_norm = (feat1 - mean) / std
+ feat2_norm = (feat2 - mean) / std
+ return feat1_norm, feat2_norm
+
+
+def get_decoder_layer_names(decoder: nn.Module) -> List[str]:
+ """Helper to get potential layer names for hooks in the VAE decoder."""
+ layer_names = []
+ for name, module in decoder.named_modules():
+ # Example: Target ResnetBlocks and potentially UpBlocks
+ if isinstance(module, (diffusers.models.resnet.ResnetBlock2D, diffusers.models.unet_2d_blocks.UpBlock2D)):
+ # Filter out redundant names if UpBlock contains ResnetBlocks already named
+ is_child = any(
+ name.startswith(parent + ".")
+ for parent in layer_names
+ if isinstance(decoder.get_submodule(parent), diffusers.models.unet_2d_blocks.UpBlock2D)
+ )
+ if not is_child:
+ layer_names.append(name)
+ # A basic default selection if complex logic fails
+ if not layer_names:
+ layer_names = [
+ name for name, module in decoder.named_modules() if re.match(r"up_blocks\.\d+\.resnets\.\d+$", name)
+ ]
+ return layer_names
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ validation_prompt: str = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+ vae_path: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
+{img_str}
+
+Special VAE used for training: {vae_path}.
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="LPL based training script of Stable Diffusion XL.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sdxl-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--timestep_bias_strategy",
+ type=str,
+ default="none",
+ choices=["earlier", "later", "range", "none"],
+ help=(
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
+ " Choices: ['earlier', 'later', 'range', 'none']."
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_multiplier",
+ type=float,
+ default=1.0,
+ help=(
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_begin",
+ type=int,
+ default=0,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
+ " Defaults to zero, which equates to having no specific bias."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_end",
+ type=int,
+ default=1000,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_portion",
+ type=float,
+ default=0.25,
+ help=(
+ "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
+ " whether the biased portions are in the earlier or later timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+ parser.add_argument(
+ "--use_lpl",
+ action="store_true",
+ help="Whether to use Latent Perceptual Loss (LPL). Increases memory usage.",
+ )
+ parser.add_argument(
+ "--lpl_weight",
+ type=float,
+ default=1.0,
+ help="Weight for the Latent Perceptual Loss.",
+ )
+ parser.add_argument(
+ "--lpl_t_threshold",
+ type=int,
+ default=200,
+ help="Apply LPL only for timesteps t < lpl_t_threshold. Corresponds to high SNR.",
+ )
+ parser.add_argument(
+ "--lpl_loss_type",
+ type=str,
+ default="mse",
+ choices=["mse", "l1"],
+ help="Type of loss to use for LPL.",
+ )
+ parser.add_argument(
+ "--lpl_norm_type",
+ type=str,
+ default="default",
+ choices=["default", "shared", "batch"],
+ help="Type of normalization to use for LPL features.",
+ )
+ parser.add_argument(
+ "--lpl_pow_law",
+ action="store_true",
+ help="Whether to use power law weighting for LPL layers.",
+ )
+ parser.add_argument(
+ "--lpl_num_blocks",
+ type=int,
+ default=4,
+ help="Number of up blocks to use for LPL feature extraction.",
+ )
+ parser.add_argument(
+ "--lpl_remove_outliers",
+ action="store_true",
+ help="Whether to remove outliers in LPL feature maps.",
+ )
+ parser.add_argument(
+ "--lpl_scale",
+ action="store_true",
+ help="Whether to scale LPL loss by noise level weights.",
+ )
+ parser.add_argument(
+ "--lpl_start",
+ type=int,
+ default=0,
+ help="Step to start applying LPL loss.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
+ prompt_embeds_list = []
+ prompt_batch = batch[caption_column]
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ return_dict=False,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
+
+
+def compute_vae_encodings(batch, vae):
+ images = batch.pop("pixel_values")
+ pixel_values = torch.stack(list(images))
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
+
+ with torch.no_grad():
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+
+    # There might be a slight performance improvement
+ # by changing model_input.cpu() to accelerator.gather(model_input)
+ return {"model_input": model_input.cpu()}
+
+
+def generate_timestep_weights(args, num_timesteps):
+ weights = torch.ones(num_timesteps)
+
+ # Determine the indices to bias
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
+
+ if args.timestep_bias_strategy == "later":
+ bias_indices = slice(-num_to_bias, None)
+ elif args.timestep_bias_strategy == "earlier":
+ bias_indices = slice(0, num_to_bias)
+ elif args.timestep_bias_strategy == "range":
+        # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
+ range_begin = args.timestep_bias_begin
+ range_end = args.timestep_bias_end
+ if range_begin < 0:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
+ )
+ if range_end > num_timesteps:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
+ )
+ bias_indices = slice(range_begin, range_end)
+ else: # 'none' or any other string
+ return weights
+ if args.timestep_bias_multiplier <= 0:
+        raise ValueError(
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
+ )
+
+ # Apply the bias
+ weights[bias_indices] *= args.timestep_bias_multiplier
+
+ # Normalize
+ weights /= weights.sum()
+
+ return weights
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Check for terminal SNR in combination with SNR Gamma
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # Freeze vae and text encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
+
+    # For mixed precision training we cast all non-trainable weights to half-precision,
+    # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ unet.enable_npu_flash_attention()
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop the weight so that the corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+                # load the diffusers-style weights into the model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
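+    # With --scale_lr the base learning rate is scaled linearly with the effective batch size,
+    # e.g. 1e-4 * 4 (accumulation) * 16 (batch) * 2 (processes) = 1.28e-2 (illustrative numbers).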
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = unet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+                f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+                f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
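+        # original_sizes and crop_top_lefts are recorded per image and later turned into the SDXL
+        # micro-conditioning `add_time_ids` (see compute_time_ids in the training loop below).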
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory. We will pre-compute the VAE encodings too.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ caption_column=args.caption_column,
+ )
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ new_fingerprint_for_vae = Hasher.hash((vae_path, args))
+ train_dataset_with_embeddings = train_dataset.map(
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+ )
+ train_dataset_with_vae = train_dataset.map(
+ compute_vae_encodings_fn,
+ batched=True,
+ batch_size=args.train_batch_size,
+ new_fingerprint=new_fingerprint_for_vae,
+ )
+ precomputed_dataset = concatenate_datasets(
+ [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+ )
+ precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
+
+ del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
+ del text_encoders, tokenizers
+ if not args.use_lpl:
+ del vae
+ gc.collect()
+
+ if is_torch_npu_available():
+ torch_npu.npu.empty_cache()
+ elif torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def collate_fn(examples):
+ model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+
+ return {
+ "model_input": model_input,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ precomputed_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
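+    # e.g. 10,000 precomputed samples with --train_batch_size 16 and --gradient_accumulation_steps 4
+    # give ceil(625 / 4) = 157 update steps per epoch per process (illustrative numbers).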
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
+
+ if args.use_lpl:
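+        # Latent Perceptual Loss (LPL) is computed in the VAE's feature space (see LatentPerceptualLoss),
+        # comparing the predicted clean latents (z0_hat) against the ground-truth latents; this is also
+        # why the frozen VAE is kept in memory when --use_lpl is set.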
+ lpl_fn = LatentPerceptualLoss(
+ vae=vae,
+ loss_type=args.lpl_loss_type,
+ grad_ckpt=args.gradient_checkpointing,
+ pow_law=args.lpl_pow_law,
+ norm_type=args.lpl_norm_type,
+ num_mid_blocks=args.lpl_num_blocks,
+ feature_type="feature",
+ remove_outliers=args.lpl_remove_outliers,
+ )
+ lpl_fn.to(accelerator.device)
+ else:
+ lpl_fn = None
+
+ # Function for unwrapping if torch.compile() was used in accelerate.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(precomputed_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ # Get scheduler alphas and sigmas for LPL z0_hat calculation
+ alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device)
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Sample noise that we'll add to the latents
+ model_input = batch["model_input"].to(accelerator.device)
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
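+                    # The offset adds a small per-(sample, channel) shift to the noise mean, which the
+                    # linked post reports helps with generating very dark or very bright images.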
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ if args.timestep_bias_strategy == "none":
+ # Sample a random timestep for each image without bias.
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ else:
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+ model_input.device
+ )
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
+
+ # time ids
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
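+                    # SDXL conditions on six extra values:
+                    # (original_h, original_w, crop_top, crop_left, target_h, target_w).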
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ elif noise_scheduler.config.prediction_type == "sample":
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
+ target = model_input
+ # We will have to subtract the noise residual from the prediction to get the target sample.
+ model_pred = model_pred - noise
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
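+                    # The per-sample weight is min(SNR(t), gamma); for epsilon-prediction it is divided
+                    # by SNR(t), and for v-prediction by SNR(t) + 1, matching the code below.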
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+                lpl_loss_value = torch.tensor(0.0, device=accelerator.device)
+                # Default mask so the LPL logging below is well-defined even on steps where LPL is skipped.
+                lpl_mask = torch.zeros(bsz, dtype=torch.bool, device=accelerator.device)
+ if args.use_lpl and lpl_fn is not None and global_step >= args.lpl_start:
+ # Apply LPL only below the timestep threshold
+ lpl_mask = timesteps < args.lpl_t_threshold
+ if lpl_mask.any():
+ # Select samples that meet the threshold
+ masked_indices = torch.where(lpl_mask)[0]
+ z0_masked = model_input[masked_indices]
+ zt_masked = noisy_model_input[masked_indices]
+ t_masked = timesteps[masked_indices]
+ model_pred_masked = model_pred[masked_indices]
+
+ # Calculate z0_hat for the masked samples
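+                        # For epsilon-prediction: z0_hat = (z_t - sigma_t * eps_hat) / alpha_t, with
+                        # alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]);
+                        # for v-prediction: z0_hat = alpha_t * z_t - sigma_t * v_hat.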
+ alpha_t = alphas_cumprod[t_masked].sqrt().to(torch.float32)
+ sigma_t = (1 - alphas_cumprod[t_masked]).sqrt().to(torch.float32)
+ alpha_t = alpha_t.view(-1, 1, 1, 1)
+ sigma_t = sigma_t.view(-1, 1, 1, 1)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ z0_hat_masked = (zt_masked.float() - sigma_t * model_pred_masked.float()) / alpha_t
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ z0_hat_masked = alpha_t * zt_masked.float() - sigma_t * model_pred_masked.float()
+ else: # sample prediction
+ z0_hat_masked = model_pred_masked.float()
+
+ with accelerator.autocast():
+ lpl_loss_value = lpl_fn.get_loss(z0_hat_masked, z0_masked)
+
+ if args.lpl_scale:
+ if args.snr_gamma is not None:
+                                # Use SNR-based weights if available. Use a separate name here so the
+                                # batch-level `snr` computed above (used for logging) is not shadowed.
+                                masked_snr = compute_snr(noise_scheduler, t_masked)
+                                snr_weights = torch.stack(
+                                    [masked_snr, args.snr_gamma * torch.ones_like(t_masked)], dim=1
+                                ).min(dim=1)[0]
+                                if noise_scheduler.config.prediction_type == "epsilon":
+                                    snr_weights = snr_weights / masked_snr
+                                elif noise_scheduler.config.prediction_type == "v_prediction":
+                                    snr_weights = snr_weights / (masked_snr + 1)
+ lpl_loss_value = (lpl_loss_value * snr_weights).mean()
+ else:
+ # If no SNR weighting, just use mean
+ lpl_loss_value = lpl_loss_value.mean()
+ else:
+ lpl_loss_value = lpl_loss_value.mean()
+
+ # Combine losses
+ total_loss = loss + args.lpl_weight * lpl_loss_value
+
+ # Gather the losses across all processes for logging
+ avg_loss = accelerator.gather(total_loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(total_loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+
+ # Enhanced logging for LPL metrics
+ log_data = {
+ "train_loss": train_loss,
+ "diffusion_loss": loss.item(),
+ "learning_rate": lr_scheduler.get_last_lr()[0],
+ }
+
+ if args.use_lpl and lpl_fn is not None and global_step >= args.lpl_start:
+ if lpl_mask.any():
+ # LPL application statistics
+ log_data.update(
+ {
+ "lpl/loss": lpl_loss_value.item(),
+ "lpl/num_samples": lpl_mask.sum().item(),
+ "lpl/application_ratio": lpl_mask.float().mean().item(),
+ "lpl/weight": args.lpl_weight,
+ "lpl/weighted_loss": (args.lpl_weight * lpl_loss_value).item(),
+ }
+ )
+
+ # SNR statistics for LPL-applied samples
+ if args.snr_gamma is not None:
+ snr_values = snr[masked_indices]
+ log_data.update(
+ {
+ "lpl/snr_mean": snr_values.mean().item(),
+ "lpl/snr_std": snr_values.std().item(),
+ "lpl/snr_min": snr_values.min().item(),
+ "lpl/snr_max": snr_values.max().item(),
+ }
+ )
+
+ # Feature statistics if available
+ if hasattr(lpl_fn, "last_feature_stats"):
+ for layer_idx, stats in enumerate(lpl_fn.last_feature_stats):
+ log_data.update(
+ {
+ f"lpl/features/layer_{layer_idx}/mean": stats["mean"],
+ f"lpl/features/layer_{layer_idx}/std": stats["std"],
+ f"lpl/features/layer_{layer_idx}/outlier_ratio": stats.get(
+ "outlier_ratio", 0.0
+ ),
+ }
+ )
+
+ # Memory usage if available
+ if torch.cuda.is_available():
+ log_data.update(
+ {
+ "lpl/memory/allocated": torch.cuda.memory_allocated() / 1024**2, # MB
+ "lpl/memory/reserved": torch.cuda.memory_reserved() / 1024**2, # MB
+ }
+ )
+
+ # Log to accelerator
+                    accelerator.log(log_data, step=global_step)
+                    train_loss = 0.0
+
+ # Update progress bar with more metrics
+ progress_bar_logs = {
+ "loss": loss.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ }
+ if args.use_lpl and lpl_loss_value.item() > 0:
+ progress_bar_logs.update(
+ {
+ "lpl": lpl_loss_value.item(),
+ "lpl_ratio": lpl_mask.float().mean().item() if lpl_mask.any() else 0.0,
+ }
+ )
+ progress_bar.set_postfix(**progress_bar_logs)
+
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # create pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = (
+ torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ if args.seed is not None
+ else None
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with autocast_ctx:
+ images = [
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if is_torch_npu_available():
+ torch_npu.npu.empty_cache()
+ elif torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ # Serialize pipeline.
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unet,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline.save_pretrained(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = (
+ torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ )
+
+ with autocast_ctx:
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id=repo_id,
+ images=images,
+ validation_prompt=args.validation_prompt,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/onnxruntime/text_to_image/README.md b/examples/research_projects/onnxruntime/text_to_image/README.md
index f398f081663a..1d688471ba74 100644
--- a/examples/research_projects/onnxruntime/text_to_image/README.md
+++ b/examples/research_projects/onnxruntime/text_to_image/README.md
@@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode
___Note___:
-___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___
+___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
## Running locally with PyTorch
diff --git a/examples/research_projects/sdxl_flax/sdxl_single.py b/examples/research_projects/sdxl_flax/sdxl_single.py
index 5b9b862d99b5..c3cbf6ca24f0 100644
--- a/examples/research_projects/sdxl_flax/sdxl_single.py
+++ b/examples/research_projects/sdxl_flax/sdxl_single.py
@@ -18,7 +18,7 @@
NUM_DEVICES = jax.device_count()
# 1. Let's start by downloading the model and loading it into our pipeline class
-# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
# will have to be passed to the pipeline during inference
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
diff --git a/examples/server-async/utils/requestscopedpipeline.py b/examples/server-async/utils/requestscopedpipeline.py
index 57d1e2567169..9c3276c31c69 100644
--- a/examples/server-async/utils/requestscopedpipeline.py
+++ b/examples/server-async/utils/requestscopedpipeline.py
@@ -7,16 +7,12 @@
from diffusers.utils import logging
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
logger = logging.get_logger(__name__)
-def safe_tokenize(tokenizer, *args, lock, **kwargs):
- with lock:
- return tokenizer(*args, **kwargs)
-
-
class RequestScopedPipeline:
DEFAULT_MUTABLE_ATTRS = [
"_all_hooks",
@@ -38,23 +34,40 @@ def __init__(
wrap_scheduler: bool = True,
):
self._base = pipeline
+
self.unet = getattr(pipeline, "unet", None)
self.vae = getattr(pipeline, "vae", None)
self.text_encoder = getattr(pipeline, "text_encoder", None)
self.components = getattr(pipeline, "components", None)
+ self.transformer = getattr(pipeline, "transformer", None)
+
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
+
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
+ self._vae_lock = threading.Lock()
+ self._image_lock = threading.Lock()
+
self._auto_detect_mutables = bool(auto_detect_mutables)
self._tensor_numel_threshold = int(tensor_numel_threshold)
-
self._auto_detected_attrs: List[str] = []
+ def _detect_kernel_pipeline(self, pipeline) -> bool:
+ kernel_indicators = [
+ "text_encoding_cache",
+ "memory_manager",
+ "enable_optimizations",
+ "_create_request_context",
+ "get_optimization_stats",
+ ]
+
+ return any(hasattr(pipeline, attr) for attr in kernel_indicators)
+
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
base_sched = getattr(self._base, "scheduler", None)
if base_sched is None:
@@ -70,11 +83,21 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
)
except Exception as e:
- logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+ logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
try:
- return copy.deepcopy(wrapped_scheduler)
- except Exception as e:
- logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+ if hasattr(wrapped_scheduler, "scheduler"):
+ try:
+ copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
+ return BaseAsyncScheduler(copied_scheduler)
+ except Exception:
+ return wrapped_scheduler
+ else:
+ copied_scheduler = copy.copy(wrapped_scheduler)
+ return BaseAsyncScheduler(copied_scheduler)
+ except Exception as e2:
+ logger.warning(
+ f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
+ )
return wrapped_scheduler
def _autodetect_mutables(self, max_attrs: int = 40):
@@ -86,6 +109,7 @@ def _autodetect_mutables(self, max_attrs: int = 40):
candidates: List[str] = []
seen = set()
+
for name in dir(self._base):
if name.startswith("__"):
continue
@@ -93,6 +117,7 @@ def _autodetect_mutables(self, max_attrs: int = 40):
continue
if name in ("to", "save_pretrained", "from_pretrained"):
continue
+
try:
val = getattr(self._base, name)
except Exception:
@@ -100,11 +125,9 @@ def _autodetect_mutables(self, max_attrs: int = 40):
import types
- # skip callables and modules
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
continue
- # containers -> candidate
if isinstance(val, (dict, list, set, tuple, bytearray)):
candidates.append(name)
seen.add(name)
@@ -205,6 +228,9 @@ def _is_tokenizer_component(self, component) -> bool:
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
+ def _should_wrap_tokenizers(self) -> bool:
+ return True
+
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
@@ -214,6 +240,25 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
local_pipe = copy.deepcopy(self._base)
+ try:
+ if (
+ hasattr(local_pipe, "vae")
+ and local_pipe.vae is not None
+ and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
+ ):
+ local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
+
+ if (
+ hasattr(local_pipe, "image_processor")
+ and local_pipe.image_processor is not None
+ and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
+ ):
+ local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
+ local_pipe.image_processor, self._image_lock
+ )
+ except Exception as e:
+ logger.debug(f"Could not wrap vae/image_processor: {e}")
+
if local_scheduler is not None:
try:
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
@@ -231,47 +276,42 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
self._clone_mutable_attrs(self._base, local_pipe)
- # 4) wrap tokenizers on the local pipe with the lock wrapper
- tokenizer_wrappers = {} # name -> original_tokenizer
- try:
- # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
- for name in dir(local_pipe):
- if "tokenizer" in name and not name.startswith("_"):
- tok = getattr(local_pipe, name, None)
- if tok is not None and self._is_tokenizer_component(tok):
- tokenizer_wrappers[name] = tok
- setattr(
- local_pipe,
- name,
- lambda *args, tok=tok, **kwargs: safe_tokenize(
- tok, *args, lock=self._tokenizer_lock, **kwargs
- ),
- )
-
- # b) wrap tokenizers in components dict
- if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
- for key, val in local_pipe.components.items():
- if val is None:
- continue
-
- if self._is_tokenizer_component(val):
- tokenizer_wrappers[f"components[{key}]"] = val
- local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
- tokenizer, *args, lock=self._tokenizer_lock, **kwargs
- )
+ original_tokenizers = {}
- except Exception as e:
- logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+ if self._should_wrap_tokenizers():
+ try:
+ for name in dir(local_pipe):
+ if "tokenizer" in name and not name.startswith("_"):
+ tok = getattr(local_pipe, name, None)
+ if tok is not None and self._is_tokenizer_component(tok):
+ if not isinstance(tok, ThreadSafeTokenizerWrapper):
+ original_tokenizers[name] = tok
+ wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
+ setattr(local_pipe, name, wrapped_tokenizer)
+
+ if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+ for key, val in local_pipe.components.items():
+ if val is None:
+ continue
+
+ if self._is_tokenizer_component(val):
+ if not isinstance(val, ThreadSafeTokenizerWrapper):
+ original_tokenizers[f"components[{key}]"] = val
+ wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
+ local_pipe.components[key] = wrapped_tokenizer
+
+ except Exception as e:
+ logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
result = None
cm = getattr(local_pipe, "model_cpu_offload_context", None)
+
try:
if callable(cm):
try:
with cm():
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
except TypeError:
- # cm might be a context manager instance rather than callable
try:
with cm:
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
@@ -279,18 +319,18 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
else:
- # no offload context available — call directly
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
return result
finally:
try:
- for name, tok in tokenizer_wrappers.items():
+ for name, tok in original_tokenizers.items():
if name.startswith("components["):
key = name[len("components[") : -1]
- local_pipe.components[key] = tok
+ if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+ local_pipe.components[key] = tok
else:
setattr(local_pipe, name, tok)
except Exception as e:
- logger.debug(f"Error restoring wrapped tokenizers: {e}")
+ logger.debug(f"Error restoring original tokenizers: {e}")
diff --git a/examples/server-async/utils/wrappers.py b/examples/server-async/utils/wrappers.py
new file mode 100644
index 000000000000..1e8474eabf3f
--- /dev/null
+++ b/examples/server-async/utils/wrappers.py
@@ -0,0 +1,86 @@
+class ThreadSafeTokenizerWrapper:
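+    # Thin thread-safety proxy: the tokenizer methods listed in `_thread_safe_methods` (and direct
+    # calls via __call__) are serialized behind the shared lock, while all other attribute access is
+    # delegated to the wrapped tokenizer unchanged.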
+ def __init__(self, tokenizer, lock):
+ self._tokenizer = tokenizer
+ self._lock = lock
+
+ self._thread_safe_methods = {
+ "__call__",
+ "encode",
+ "decode",
+ "tokenize",
+ "encode_plus",
+ "batch_encode_plus",
+ "batch_decode",
+ }
+
+ def __getattr__(self, name):
+ attr = getattr(self._tokenizer, name)
+
+ if name in self._thread_safe_methods and callable(attr):
+
+ def wrapped_method(*args, **kwargs):
+ with self._lock:
+ return attr(*args, **kwargs)
+
+ return wrapped_method
+
+ return attr
+
+ def __call__(self, *args, **kwargs):
+ with self._lock:
+ return self._tokenizer(*args, **kwargs)
+
+ def __setattr__(self, name, value):
+ if name.startswith("_"):
+ super().__setattr__(name, value)
+ else:
+ setattr(self._tokenizer, name, value)
+
+ def __dir__(self):
+ return dir(self._tokenizer)
+
+
+class ThreadSafeVAEWrapper:
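+    # Same proxy pattern as the tokenizer wrapper, but only encode/decode/forward are serialized.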
+ def __init__(self, vae, lock):
+ self._vae = vae
+ self._lock = lock
+
+ def __getattr__(self, name):
+ attr = getattr(self._vae, name)
+ if name in {"decode", "encode", "forward"} and callable(attr):
+
+ def wrapped(*args, **kwargs):
+ with self._lock:
+ return attr(*args, **kwargs)
+
+ return wrapped
+ return attr
+
+ def __setattr__(self, name, value):
+ if name.startswith("_"):
+ super().__setattr__(name, value)
+ else:
+ setattr(self._vae, name, value)
+
+
+class ThreadSafeImageProcessorWrapper:
+ def __init__(self, proc, lock):
+ self._proc = proc
+ self._lock = lock
+
+ def __getattr__(self, name):
+ attr = getattr(self._proc, name)
+ if name in {"postprocess", "preprocess"} and callable(attr):
+
+ def wrapped(*args, **kwargs):
+ with self._lock:
+ return attr(*args, **kwargs)
+
+ return wrapped
+ return attr
+
+ def __setattr__(self, name, value):
+ if name.startswith("_"):
+ super().__setattr__(name, value)
+ else:
+ setattr(self._proc, name, value)
diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py
index 2973913fa215..a8fa6f87eee1 100644
--- a/scripts/convert_flux2_to_diffusers.py
+++ b/scripts/convert_flux2_to_diffusers.py
@@ -44,7 +44,7 @@
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
-parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
+parser.add_argument("--dit_filename", default="flux2-dev.safetensors", type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--dit", action="store_true")
parser.add_argument("--vae_dtype", type=str, default="fp32")
@@ -385,9 +385,9 @@ def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) ->
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
- if model_type == "test" or model_type == "dummy-flux2":
+ if model_type == "flux2-dev":
config = {
- "model_id": "diffusers-internal-dev/dummy-flux2",
+ "model_id": "black-forest-labs/FLUX.2-dev",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
@@ -405,6 +405,53 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "klein-4b":
+ config = {
+ "model_id": "diffusers-internal-dev/dummy0115",
+ "diffusers_config": {
+ "patch_size": 1,
+ "in_channels": 128,
+ "num_layers": 5,
+ "num_single_layers": 20,
+ "attention_head_dim": 128,
+ "num_attention_heads": 24,
+ "joint_attention_dim": 7680,
+ "timestep_guidance_channels": 256,
+ "mlp_ratio": 3.0,
+ "axes_dims_rope": (32, 32, 32, 32),
+ "rope_theta": 2000,
+ "eps": 1e-6,
+ "guidance_embeds": False,
+ },
+ }
+ rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
+ special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
+
+ elif model_type == "klein-9b":
+ config = {
+ "model_id": "diffusers-internal-dev/dummy0115",
+ "diffusers_config": {
+ "patch_size": 1,
+ "in_channels": 128,
+ "num_layers": 8,
+ "num_single_layers": 24,
+ "attention_head_dim": 128,
+ "num_attention_heads": 32,
+ "joint_attention_dim": 12288,
+ "timestep_guidance_channels": 256,
+ "mlp_ratio": 3.0,
+ "axes_dims_rope": (32, 32, 32, 32),
+ "rope_theta": 2000,
+ "eps": 1e-6,
+ "guidance_embeds": False,
+ },
+ }
+ rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
+ special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
+
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}. Choose from: flux2-dev, klein-4b, klein-9b")
+
return config, rename_dict, special_keys_remap
@@ -447,7 +494,14 @@ def main(args):
if args.dit:
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
- transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
+
+ if "klein-4b" in args.dit_filename:
+ model_type = "klein-4b"
+ elif "klein-9b" in args.dit_filename:
+ model_type = "klein-9b"
+ else:
+ model_type = "flux2-dev"
+ transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, model_type)
if not args.full_pipe:
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
@@ -465,8 +519,15 @@ def main(args):
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
+ if_distilled = "base" not in args.dit_filename
+
pipe = Flux2Pipeline(
- vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ if_distilled=if_distilled,
)
pipe.save_pretrained(args.output_path)
diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py
new file mode 100644
index 000000000000..5367113365a2
--- /dev/null
+++ b/scripts/convert_ltx2_to_diffusers.py
@@ -0,0 +1,886 @@
+import argparse
+import os
+from contextlib import nullcontext
+from typing import Any, Dict, Optional, Tuple
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
+
+from diffusers import (
+ AutoencoderKLLTX2Audio,
+ AutoencoderKLLTX2Video,
+ FlowMatchEulerDiscreteScheduler,
+ LTX2LatentUpsamplePipeline,
+ LTX2Pipeline,
+ LTX2VideoTransformer3DModel,
+)
+from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+
+LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
+ # Input Patchify Projections
+ "patchify_proj": "proj_in",
+ "audio_patchify_proj": "audio_proj_in",
+ # Modulation Parameters
+    # Handle adaln_single --> time_embed and audio_adaln_single --> audio_time_embed separately, as the
+    # original keys are substrings of the other modulation parameter keys below
+ "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
+ "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
+ "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
+ "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
+ # Transformer Blocks
+    # Per-Block Cross Attention Modulation Parameters
+ "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
+ "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
+ # Attention QK Norms
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+}
+
+LTX_2_0_VIDEO_VAE_RENAME_DICT = {
+ # Encoder
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.1",
+ "down_blocks.3": "down_blocks.1.downsamplers.0",
+ "down_blocks.4": "down_blocks.2",
+ "down_blocks.5": "down_blocks.2.downsamplers.0",
+ "down_blocks.6": "down_blocks.3",
+ "down_blocks.7": "down_blocks.3.downsamplers.0",
+ "down_blocks.8": "mid_block",
+ # Decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ # Common
+ # For all 3D ResNets
+ "res_blocks": "resnets",
+ "per_channel_statistics.mean-of-means": "latents_mean",
+ "per_channel_statistics.std-of-means": "latents_std",
+}
+
+LTX_2_0_AUDIO_VAE_RENAME_DICT = {
+ "per_channel_statistics.mean-of-means": "latents_mean",
+ "per_channel_statistics.std-of-means": "latents_std",
+}
+
+LTX_2_0_VOCODER_RENAME_DICT = {
+ "ups": "upsamplers",
+ "resblocks": "resnets",
+ "conv_pre": "conv_in",
+ "conv_post": "conv_out",
+}
+
+LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
+ "video_embeddings_connector": "video_connector",
+ "audio_embeddings_connector": "audio_connector",
+ "transformer_1d_blocks": "transformer_blocks",
+ # Attention QK Norms
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+}
+
+
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
+ state_dict.pop(key)
+
+
+def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
+ # Skip if not a weight, bias
+ if ".weight" not in key and ".bias" not in key:
+ return
+
+ if key.startswith("adaln_single."):
+ new_key = key.replace("adaln_single.", "time_embed.")
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+
+ if key.startswith("audio_adaln_single."):
+ new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+
+ return
+
+
+def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
+ if key.startswith("per_channel_statistics"):
+ new_key = ".".join(["decoder", key])
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+
+ return
+
+
+LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "video_embeddings_connector": remove_keys_inplace,
+ "audio_embeddings_connector": remove_keys_inplace,
+ "adaln_single": convert_ltx2_transformer_adaln_single,
+}
+
+LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
+ "connectors.": "",
+ "video_embeddings_connector": "video_connector",
+ "audio_embeddings_connector": "audio_connector",
+ "transformer_1d_blocks": "transformer_blocks",
+ "text_embedding_projection.aggregate_embed": "text_proj_in",
+ # Attention QK Norms
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+}
+
+LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
+ "per_channel_statistics.channel": remove_keys_inplace,
+ "per_channel_statistics.mean-of-stds": remove_keys_inplace,
+}
+
+LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
+
+LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
+
+
+def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ connector_prefixes = (
+ "video_embeddings_connector",
+ "audio_embeddings_connector",
+ "transformer_1d_blocks",
+ "text_embedding_projection.aggregate_embed",
+ "connectors.",
+ "video_connector",
+ "audio_connector",
+ "text_proj_in",
+ )
+
+ transformer_state_dict, connector_state_dict = {}, {}
+ for key, value in state_dict.items():
+ if key.startswith(connector_prefixes):
+ connector_state_dict[key] = value
+ else:
+ transformer_state_dict[key] = value
+
+ return transformer_state_dict, connector_state_dict
+
+
+def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+ if version == "test":
+ # Produces a transformer of the same size as used in test_models_transformer_ltx2.py
+ config = {
+ "model_id": "diffusers-internal-dev/dummy-ltx2",
+ "diffusers_config": {
+ "in_channels": 4,
+ "out_channels": 4,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "cross_attention_dim": 16,
+ "vae_scale_factors": (8, 32, 32),
+ "pos_embed_max_pos": 20,
+ "base_height": 2048,
+ "base_width": 2048,
+ "audio_in_channels": 4,
+ "audio_out_channels": 4,
+ "audio_patch_size": 1,
+ "audio_patch_size_t": 1,
+ "audio_num_attention_heads": 2,
+ "audio_attention_head_dim": 4,
+ "audio_cross_attention_dim": 8,
+ "audio_scale_factor": 4,
+ "audio_pos_embed_max_pos": 20,
+ "audio_sampling_rate": 16000,
+ "audio_hop_length": 160,
+ "num_layers": 2,
+ "activation_fn": "gelu-approximate",
+ "qk_norm": "rms_norm_across_heads",
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "caption_channels": 16,
+ "attention_bias": True,
+ "attention_out_bias": True,
+ "rope_theta": 10000.0,
+ "rope_double_precision": False,
+ "causal_offset": 1,
+ "timestep_scale_multiplier": 1000,
+ "cross_attn_timestep_scale_multiplier": 1,
+ },
+ }
+ rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
+ special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif version == "2.0":
+ config = {
+ "model_id": "diffusers-internal-dev/new-ltx-model",
+ "diffusers_config": {
+ "in_channels": 128,
+ "out_channels": 128,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": 32,
+ "attention_head_dim": 128,
+ "cross_attention_dim": 4096,
+ "vae_scale_factors": (8, 32, 32),
+ "pos_embed_max_pos": 20,
+ "base_height": 2048,
+ "base_width": 2048,
+ "audio_in_channels": 128,
+ "audio_out_channels": 128,
+ "audio_patch_size": 1,
+ "audio_patch_size_t": 1,
+ "audio_num_attention_heads": 32,
+ "audio_attention_head_dim": 64,
+ "audio_cross_attention_dim": 2048,
+ "audio_scale_factor": 4,
+ "audio_pos_embed_max_pos": 20,
+ "audio_sampling_rate": 16000,
+ "audio_hop_length": 160,
+ "num_layers": 48,
+ "activation_fn": "gelu-approximate",
+ "qk_norm": "rms_norm_across_heads",
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "caption_channels": 3840,
+ "attention_bias": True,
+ "attention_out_bias": True,
+ "rope_theta": 10000.0,
+ "rope_double_precision": True,
+ "causal_offset": 1,
+ "timestep_scale_multiplier": 1000,
+ "cross_attn_timestep_scale_multiplier": 1000,
+ "rope_type": "split",
+ },
+ }
+ rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
+ special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
+ return config, rename_dict, special_keys_remap
+
+
+def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+ if version == "test":
+ config = {
+ "model_id": "diffusers-internal-dev/dummy-ltx2",
+ "diffusers_config": {
+ "caption_channels": 16,
+ "text_proj_in_factor": 3,
+ "video_connector_num_attention_heads": 4,
+ "video_connector_attention_head_dim": 8,
+ "video_connector_num_layers": 1,
+ "video_connector_num_learnable_registers": None,
+ "audio_connector_num_attention_heads": 4,
+ "audio_connector_attention_head_dim": 8,
+ "audio_connector_num_layers": 1,
+ "audio_connector_num_learnable_registers": None,
+ "connector_rope_base_seq_len": 32,
+ "rope_theta": 10000.0,
+ "rope_double_precision": False,
+ "causal_temporal_positioning": False,
+ },
+ }
+ elif version == "2.0":
+ config = {
+ "model_id": "diffusers-internal-dev/new-ltx-model",
+ "diffusers_config": {
+ "caption_channels": 3840,
+ "text_proj_in_factor": 49,
+ "video_connector_num_attention_heads": 30,
+ "video_connector_attention_head_dim": 128,
+ "video_connector_num_layers": 2,
+ "video_connector_num_learnable_registers": 128,
+ "audio_connector_num_attention_heads": 30,
+ "audio_connector_attention_head_dim": 128,
+ "audio_connector_num_layers": 2,
+ "audio_connector_num_learnable_registers": 128,
+ "connector_rope_base_seq_len": 4096,
+ "rope_theta": 10000.0,
+ "rope_double_precision": True,
+ "causal_temporal_positioning": False,
+ "rope_type": "split",
+ },
+ }
+
+ rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
+ special_keys_remap = {}
+
+ return config, rename_dict, special_keys_remap
+
+
+def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
+ config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
+ diffusers_config = config["diffusers_config"]
+
+ transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
+
+ with init_empty_weights():
+ transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(transformer_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in rename_dict.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(transformer_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(transformer_state_dict.keys()):
+ for special_key, handler_fn_inplace in special_keys_remap.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, transformer_state_dict)
+
+ transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
+ config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
+ diffusers_config = config["diffusers_config"]
+
+ _, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
+ if len(connector_state_dict) == 0:
+ raise ValueError("No connector weights found in the provided state dict.")
+
+ with init_empty_weights():
+ connectors = LTX2TextConnectors.from_config(diffusers_config)
+
+ for key in list(connector_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in rename_dict.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(connector_state_dict, key, new_key)
+
+ for key in list(connector_state_dict.keys()):
+ for special_key, handler_fn_inplace in special_keys_remap.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, connector_state_dict)
+
+ connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
+ return connectors
+
+
+def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+ if version == "test":
+ config = {
+ "model_id": "diffusers-internal-dev/dummy-ltx2",
+ "diffusers_config": {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 128,
+ "block_out_channels": (256, 512, 1024, 2048),
+ "down_block_types": (
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ ),
+ "decoder_block_out_channels": (256, 512, 1024),
+ "layers_per_block": (4, 6, 6, 2, 2),
+ "decoder_layers_per_block": (5, 5, 5, 5),
+ "spatio_temporal_scaling": (True, True, True, True),
+ "decoder_spatio_temporal_scaling": (True, True, True),
+ "decoder_inject_noise": (False, False, False, False),
+ "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+ "upsample_residual": (True, True, True),
+ "upsample_factor": (2, 2, 2),
+ "timestep_conditioning": False,
+ "patch_size": 4,
+ "patch_size_t": 1,
+ "resnet_norm_eps": 1e-6,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ "encoder_spatial_padding_mode": "zeros",
+ "decoder_spatial_padding_mode": "reflect",
+ "spatial_compression_ratio": 32,
+ "temporal_compression_ratio": 8,
+ },
+ }
+ rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
+ special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
+ elif version == "2.0":
+ config = {
+ "model_id": "diffusers-internal-dev/dummy-ltx2",
+ "diffusers_config": {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 128,
+ "block_out_channels": (256, 512, 1024, 2048),
+ "down_block_types": (
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ ),
+ "decoder_block_out_channels": (256, 512, 1024),
+ "layers_per_block": (4, 6, 6, 2, 2),
+ "decoder_layers_per_block": (5, 5, 5, 5),
+ "spatio_temporal_scaling": (True, True, True, True),
+ "decoder_spatio_temporal_scaling": (True, True, True),
+ "decoder_inject_noise": (False, False, False, False),
+ "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+ "upsample_residual": (True, True, True),
+ "upsample_factor": (2, 2, 2),
+ "timestep_conditioning": False,
+ "patch_size": 4,
+ "patch_size_t": 1,
+ "resnet_norm_eps": 1e-6,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ "encoder_spatial_padding_mode": "zeros",
+ "decoder_spatial_padding_mode": "reflect",
+ "spatial_compression_ratio": 32,
+ "temporal_compression_ratio": 8,
+ },
+ }
+ rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
+ special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
+ return config, rename_dict, special_keys_remap
+
+
+def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
+ config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
+ diffusers_config = config["diffusers_config"]
+
+ with init_empty_weights():
+ vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in rename_dict.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in special_keys_remap.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+ if version == "2.0":
+ config = {
+ "model_id": "diffusers-internal-dev/new-ltx-model",
+ "diffusers_config": {
+ "base_channels": 128,
+ "output_channels": 2,
+ "ch_mult": (1, 2, 4),
+ "num_res_blocks": 2,
+ "attn_resolutions": None,
+ "in_channels": 2,
+ "resolution": 256,
+ "latent_channels": 8,
+ "norm_type": "pixel",
+ "causality_axis": "height",
+ "dropout": 0.0,
+ "mid_block_add_attention": False,
+ "sample_rate": 16000,
+ "mel_hop_length": 160,
+ "is_causal": True,
+ "mel_bins": 64,
+ "double_z": True,
+ },
+ }
+ rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
+ special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
+    else:
+        raise ValueError(f"Unsupported version: {version}")
+    return config, rename_dict, special_keys_remap
+
+
+def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
+ config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
+ diffusers_config = config["diffusers_config"]
+
+ with init_empty_weights():
+ vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in rename_dict.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in special_keys_remap.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+ if version == "2.0":
+ config = {
+ "model_id": "diffusers-internal-dev/new-ltx-model",
+ "diffusers_config": {
+ "in_channels": 128,
+ "hidden_channels": 1024,
+ "out_channels": 2,
+ "upsample_kernel_sizes": [16, 15, 8, 4, 4],
+ "upsample_factors": [6, 5, 2, 2, 2],
+ "resnet_kernel_sizes": [3, 7, 11],
+ "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ "leaky_relu_negative_slope": 0.1,
+ "output_sampling_rate": 24000,
+ },
+ }
+ rename_dict = LTX_2_0_VOCODER_RENAME_DICT
+ special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
+    else:
+        raise ValueError(f"Unsupported version: {version}")
+    return config, rename_dict, special_keys_remap
+
+
+def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
+ config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
+ diffusers_config = config["diffusers_config"]
+
+ with init_empty_weights():
+ vocoder = LTX2Vocoder.from_config(diffusers_config)
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in rename_dict.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in special_keys_remap.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
+ return vocoder
+
+
+def get_ltx2_spatial_latent_upsampler_config(version: str):
+ if version == "2.0":
+ config = {
+ "in_channels": 128,
+ "mid_channels": 1024,
+ "num_blocks_per_stage": 4,
+ "dims": 3,
+ "spatial_upsample": True,
+ "temporal_upsample": False,
+ "rational_spatial_scale": 2.0,
+ }
+ else:
+ raise ValueError(f"Unsupported version: {version}")
+ return config
+
+
+def convert_ltx2_spatial_latent_upsampler(
+ original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
+):
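+    # No key remapping is needed for the latent upsampler: load the original weights directly onto the
+    # meta-initialized model (assign=True attaches the loaded tensors) and cast to the requested dtype.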
+ with init_empty_weights():
+ latent_upsampler = LTX2LatentUpsamplerModel(**config)
+
+ latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
+ latent_upsampler.to(dtype)
+ return latent_upsampler
+
+
+def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
+ if repo_id is None and filename is None:
+        raise ValueError("Please provide `filename` (optionally together with `repo_id` to download from the Hub).")
+
+ if repo_id is not None:
+ if filename is None:
+ raise ValueError("If repo_id is specified, filename must also be specified.")
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
+ else:
+ ckpt_path = filename
+
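+    # Use safetensors for .safetensors/.sft files; otherwise fall back to torch.load on CPU.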
+ _, ext = os.path.splitext(ckpt_path)
+ if ext in [".safetensors", ".sft"]:
+ state_dict = safetensors.torch.load_file(ckpt_path)
+ else:
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+
+ return state_dict
+
+
+def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
+ # Ensure that the key prefix ends with a dot (.)
+ if not prefix.endswith("."):
+ prefix = prefix + "."
+
+ model_state_dict = {}
+ for param_name, param in combined_ckpt.items():
+ if param_name.startswith(prefix):
+ model_state_dict[param_name.replace(prefix, "")] = param
+
+ if prefix == "model.diffusion_model.":
+ # Some checkpoints store the text connector projection outside the diffusion model prefix.
+ connector_key = "text_embedding_projection.aggregate_embed.weight"
+ if connector_key in combined_ckpt and connector_key not in model_state_dict:
+ model_state_dict[connector_key] = combined_ckpt[connector_key]
+
+ return model_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--original_state_dict_repo_id",
+ default="Lightricks/LTX-2",
+ type=str,
+ help="HF Hub repo id with LTX 2.0 checkpoint",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ default=None,
+ type=str,
+ help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="2.0",
+ choices=["test", "2.0"],
+ help="Version of the LTX 2.0 model",
+ )
+
+ parser.add_argument(
+ "--combined_filename",
+ default="ltx-2-19b-dev.safetensors",
+ type=str,
+ help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
+ )
+ parser.add_argument("--vae_prefix", default="vae.", type=str)
+ parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
+ parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
+ parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
+
+ parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
+ parser.add_argument(
+ "--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
+ )
+ parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
+ parser.add_argument(
+ "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
+ )
+ parser.add_argument(
+ "--text_encoder_model_id",
+ default="google/gemma-3-12b-it-qat-q4_0-unquantized",
+ type=str,
+ help="HF Hub id for the LTX 2.0 base text encoder model",
+ )
+ parser.add_argument(
+ "--tokenizer_id",
+ default="google/gemma-3-12b-it-qat-q4_0-unquantized",
+ type=str,
+ help="HF Hub id for the LTX 2.0 text tokenizer",
+ )
+ parser.add_argument(
+ "--latent_upsampler_filename",
+ default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
+ type=str,
+ help="Latent upsampler filename",
+ )
+
+ parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
+ parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
+ parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
+ parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
+ parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
+    parser.add_argument("--text_encoder", action="store_true", help="Whether to convert the text encoder")
+ parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
+ parser.add_argument(
+ "--full_pipeline",
+ action="store_true",
+        help="Whether to save the full pipeline. This will attempt to convert all component models (VAE, DiT, etc.)",
+ )
+ parser.add_argument(
+ "--upsample_pipeline",
+ action="store_true",
+ help="Whether to save a latent upsampling pipeline",
+ )
+
+ parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
+ parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
+ parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
+ parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
+ parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
+
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+
+ return parser.parse_args()
+
+
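+# Example invocation (script name is illustrative):
+#   python convert_ltx2_to_diffusers.py --version 2.0 --full_pipeline --output_path ./ltx2-diffusers
+
+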
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+VARIANT_MAPPING = {
+ "fp32": None,
+ "fp16": "fp16",
+ "bf16": "bf16",
+}
+
+
+def main(args):
+ vae_dtype = DTYPE_MAPPING[args.vae_dtype]
+ audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
+ dit_dtype = DTYPE_MAPPING[args.dit_dtype]
+ vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
+ text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
+
+ combined_ckpt = None
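+    # Only load the (large) combined checkpoint when at least one of the models stored inside it is requested.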
+ load_combined_models = any(
+ [
+ args.vae,
+ args.audio_vae,
+ args.dit,
+ args.vocoder,
+ args.text_encoder,
+ args.full_pipeline,
+ args.upsample_pipeline,
+ ]
+ )
+ if args.combined_filename is not None and load_combined_models:
+ combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
+
+ if args.vae or args.full_pipeline or args.upsample_pipeline:
+ if args.vae_filename is not None:
+ original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
+ elif combined_ckpt is not None:
+ original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
+ vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
+ if not args.full_pipeline and not args.upsample_pipeline:
+ vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
+
+ if args.audio_vae or args.full_pipeline:
+ if args.audio_vae_filename is not None:
+ original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
+ elif combined_ckpt is not None:
+ original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
+ audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
+ if not args.full_pipeline:
+ audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
+
+ if args.dit or args.full_pipeline:
+ if args.dit_filename is not None:
+ original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
+ elif combined_ckpt is not None:
+ original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
+ transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
+ if not args.full_pipeline:
+ transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
+
+ if args.connectors or args.full_pipeline:
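+        # The connector weights are stored within the DiT checkpoint, so reuse the DiT filename/prefix here.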
+ if args.dit_filename is not None:
+ original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
+ elif combined_ckpt is not None:
+ original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
+ connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
+ if not args.full_pipeline:
+ connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
+
+ if args.vocoder or args.full_pipeline:
+ if args.vocoder_filename is not None:
+ original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
+ elif combined_ckpt is not None:
+ original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
+ vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
+ if not args.full_pipeline:
+ vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
+
+ if args.text_encoder or args.full_pipeline:
+ # text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
+ text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
+ if not args.full_pipeline:
+ text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
+
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
+ if not args.full_pipeline:
+ tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
+
+ if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
+ original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
+ repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
+ )
+ latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
+ latent_upsampler = convert_ltx2_spatial_latent_upsampler(
+ original_latent_upsampler_ckpt,
+ latent_upsampler_config,
+ dtype=vae_dtype,
+ )
+ if not args.full_pipeline and not args.upsample_pipeline:
+ latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
+
+ if args.full_pipeline:
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ use_dynamic_shifting=True,
+ base_shift=0.95,
+ max_shift=2.05,
+ base_image_seq_len=1024,
+ max_image_seq_len=4096,
+ shift_terminal=0.1,
+ )
+
+ pipe = LTX2Pipeline(
+ scheduler=scheduler,
+ vae=vae,
+ audio_vae=audio_vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ connectors=connectors,
+ transformer=transformer,
+ vocoder=vocoder,
+ )
+
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+ if args.upsample_pipeline:
+ pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
+
+        # Save the latent upsampling pipeline in its own subdirectory so its files do not collide with the full pipeline
+ pipe.save_pretrained(
+ os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
+ )
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/setup.py b/setup.py
index c47124554479..d52d37787fdb 100644
--- a/setup.py
+++ b/setup.py
@@ -274,7 +274,7 @@ def run(self):
setup(
name="diffusers",
- version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index aa11a741af38..24b9c12db6d4 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -23,6 +23,7 @@
is_torchao_available,
is_torchsde_available,
is_transformers_available,
+ is_transformers_version,
)
@@ -193,6 +194,8 @@
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
+ "AutoencoderKLLTX2Audio",
+ "AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -223,6 +226,7 @@
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
+ "GlmImageTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -236,6 +240,7 @@
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
+ "LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -353,6 +358,7 @@
"KDPM2AncestralDiscreteScheduler",
"KDPM2DiscreteScheduler",
"LCMScheduler",
+ "LTXEulerAncestralRFScheduler",
"PNDMScheduler",
"RePaintScheduler",
"SASolverScheduler",
@@ -417,6 +423,8 @@
"QwenImageEditModularPipeline",
"QwenImageEditPlusAutoBlocks",
"QwenImageEditPlusModularPipeline",
+ "QwenImageLayeredAutoBlocks",
+ "QwenImageLayeredModularPipeline",
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
@@ -449,9 +457,11 @@
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
+ "BriaFiboEditPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
+ "ChromaInpaintPipeline",
"ChromaPipeline",
"ChronoEditPipeline",
"CLIPImageProjection",
@@ -472,6 +482,7 @@
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
+ "Flux2KleinPipeline",
"Flux2Pipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
@@ -486,6 +497,7 @@
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
+ "GlmImagePipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -537,7 +549,11 @@
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
+ "LTX2ImageToVideoPipeline",
+ "LTX2LatentUpsamplePipeline",
+ "LTX2Pipeline",
"LTXConditionPipeline",
+ "LTXI2VLongMultiPromptPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
@@ -937,6 +953,8 @@
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
+ AutoencoderKLLTX2Audio,
+ AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -967,6 +985,7 @@
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
+ GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -980,6 +999,7 @@
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
+ LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1088,6 +1108,7 @@
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LCMScheduler,
+ LTXEulerAncestralRFScheduler,
PNDMScheduler,
RePaintScheduler,
SASolverScheduler,
@@ -1135,6 +1156,8 @@
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
+ QwenImageLayeredAutoBlocks,
+ QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
@@ -1163,9 +1186,11 @@
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
+ BriaFiboEditPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
+ ChromaInpaintPipeline,
ChromaPipeline,
ChronoEditPipeline,
CLIPImageProjection,
@@ -1186,6 +1211,7 @@
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
+ Flux2KleinPipeline,
Flux2Pipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
@@ -1200,6 +1226,7 @@
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
+ GlmImagePipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
@@ -1251,7 +1278,11 @@
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
+ LTX2ImageToVideoPipeline,
+ LTX2LatentUpsamplePipeline,
+ LTX2Pipeline,
LTXConditionPipeline,
+ LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index ace4e8543a1c..bdd4dbbcd4b5 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -67,6 +67,7 @@ def text_encoder_attn_modules(text_encoder):
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
+ "LTX2LoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
@@ -121,6 +122,7 @@ def text_encoder_attn_modules(text_encoder):
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
+ LTX2LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 2e87f757c352..8f7309d4ed1e 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -2140,6 +2140,54 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
return converted_state_dict
+def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+    # Keep only the keys under `non_diffusers_prefix` and strip the prefix
+ state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
+ converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+
+ if non_diffusers_prefix == "diffusion_model":
+ rename_dict = {
+ "patchify_proj": "proj_in",
+ "audio_patchify_proj": "audio_proj_in",
+ "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
+ "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
+ "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
+ "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
+ "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
+ "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ }
+ else:
+ rename_dict = {"aggregate_embed": "text_proj_in"}
+
+ # Apply renaming
+ renamed_state_dict = {}
+ for key, value in converted_state_dict.items():
+ new_key = key[:]
+ for old_pattern, new_pattern in rename_dict.items():
+ new_key = new_key.replace(old_pattern, new_pattern)
+ renamed_state_dict[new_key] = value
+
+ # Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
+ final_state_dict = {}
+ for key, value in renamed_state_dict.items():
+ if key.startswith("adaln_single."):
+ new_key = key.replace("adaln_single.", "time_embed.")
+ final_state_dict[new_key] = value
+ elif key.startswith("audio_adaln_single."):
+ new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
+ final_state_dict[new_key] = value
+ else:
+ final_state_dict[key] = value
+
+ # Add transformer prefix
+ prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
+ final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}
+
+ return final_state_dict
+
+
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 03a2fe9f3f8e..24d1fd7b9308 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -48,6 +48,7 @@
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
+ _convert_non_diffusers_ltx2_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
@@ -74,6 +75,7 @@
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
+LTX2_CONNECTOR_NAME = "connectors"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
@@ -212,7 +214,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_unet(
state_dict,
@@ -639,7 +641,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_unet(
state_dict,
@@ -1079,7 +1081,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -1375,7 +1377,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -1657,7 +1659,7 @@ def load_lora_weights(
)
if not (has_lora_keys or has_norm_keys):
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
transformer_lora_state_dict = {
k: state_dict.get(k)
@@ -2504,7 +2506,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -2701,7 +2703,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -2904,7 +2906,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3011,6 +3013,233 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)
+class LTX2LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer", "connectors"]
+ transformer_name = TRANSFORMER_NAME
+ connectors_name = LTX2_CONNECTOR_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ final_state_dict = state_dict
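+        # Non-diffusers LTX 2.0 LoRAs prefix transformer keys with `diffusion_model.` and may also carry
+        # connector weights under `text_embedding_projection.`; convert both to the diffusers layout.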
+ is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
+ has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
+ if is_non_diffusers_format:
+ final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
+ if has_connector:
+ connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
+ state_dict, "text_embedding_projection"
+ )
+ final_state_dict.update(connectors_state_dict)
+ out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
+ return out
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+
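+        # Route the LoRA layers to their target modules: transformer-prefixed keys go into the
+        # transformer, connector-prefixed keys into the text connectors.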
+ transformer_peft_state_dict = {
+ k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
+ }
+ connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}
+ self.load_lora_into_transformer(
+ transformer_peft_state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ if connectors_peft_state_dict:
+ self.load_lora_into_transformer(
+ connectors_peft_state_dict,
+ transformer=getattr(self, self.connectors_name)
+ if not hasattr(self, "connectors")
+ else self.connectors,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ prefix=self.connectors_name,
+ )
+
+ @classmethod
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ prefix: str = "transformer",
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {prefix}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ prefix=prefix,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
class SanaLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
@@ -3104,7 +3333,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3307,7 +3536,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3511,7 +3740,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3711,7 +3940,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3965,7 +4194,7 @@ def load_lora_weights(
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4242,7 +4471,7 @@ def load_lora_weights(
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4462,7 +4691,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -4665,7 +4894,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -4871,7 +5100,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5077,7 +5306,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5280,7 +5509,7 @@ def load_lora_weights(
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 30a78f00b3f2..16f1a5d1ec7e 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -63,9 +63,12 @@
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
+ "ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
+ "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
+ "LTX2TextConnectors": lambda model_cls, weights: weights,
}
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index f49eca3d3688..c733bd489da3 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -40,6 +40,9 @@
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
+ convert_ltx2_audio_vae_to_diffusers,
+ convert_ltx2_transformer_to_diffusers,
+ convert_ltx2_vae_to_diffusers,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
@@ -176,6 +179,18 @@
"ZImageControlNetModel": {
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
},
+ "LTX2VideoTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "AutoencoderKLLTX2Video": {
+ "checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
+ "default_subfolder": "vae",
+ },
+ "AutoencoderKLLTX2Audio": {
+ "checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
+ "default_subfolder": "audio_vae",
+ },
}
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 9f56a27a174e..26f6c4388d19 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -112,7 +112,8 @@
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
- "vae.per_channel_statistics.mean-of-means",
+ "vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8
+ "vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
@@ -147,6 +148,11 @@
"net.pos_embedder.dim_spatial_range",
],
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
+ "ltx2": [
+ "model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
+ "vae.per_channel_statistics.mean-of-means",
+ "audio_vae.per_channel_statistics.mean-of-means",
+ ],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -226,7 +232,9 @@
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
- "z-image-turbo-controlnet-2.x": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
+ "z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
+ "z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
+ "ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
}
# Use to configure model sample size when original config is provided
@@ -784,11 +792,20 @@ def infer_diffusers_model_type(checkpoint):
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
- model_type = "z-image-turbo-controlnet-2.x"
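+            # A missing or all-zero `before_proj` weight in the control noise refiner indicates the 2.0
+            # controlnet; otherwise treat the checkpoint as 2.1.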
+ before_proj_weight = checkpoint.get("control_noise_refiner.0.before_proj.weight", None)
+ if before_proj_weight is None:
+ model_type = "z-image-turbo-controlnet-2.0"
+ elif before_proj_weight is not None and torch.all(before_proj_weight == 0.0):
+ model_type = "z-image-turbo-controlnet-2.0"
+ else:
+ model_type = "z-image-turbo-controlnet-2.1"
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
model_type = "z-image-turbo-controlnet"
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
+ model_type = "ltx2-dev"
+
else:
model_type = "v1"
@@ -3913,3 +3930,161 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa
return converted_state_dict
else:
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
+
+
+def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
+ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
+ # Transformer prefix
+ "model.diffusion_model.": "",
+ # Input Patchify Projections
+ "patchify_proj": "proj_in",
+ "audio_patchify_proj": "audio_proj_in",
+ # Modulation Parameters
+        # Handle adaln_single --> time_embed and audio_adaln_single --> audio_time_embed separately, as the original keys are
+ # substrings of the other modulation parameters below
+ "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
+ "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
+ "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
+ "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
+ # Transformer Blocks
+ # Per-Block Cross Attention Modulation Parameters
+ "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
+ "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
+ # Attention QK Norms
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ }
+
+ def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ def remove_keys_inplace(key: str, state_dict) -> None:
+ state_dict.pop(key)
+
+ def convert_ltx2_transformer_adaln_single(key: str, state_dict) -> None:
+        # Skip if the key is not a weight or bias
+ if ".weight" not in key and ".bias" not in key:
+ return
+
+ if key.startswith("adaln_single."):
+ new_key = key.replace("adaln_single.", "time_embed.")
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+
+ if key.startswith("audio_adaln_single."):
+ new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+
+ return
+
+ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "video_embeddings_connector": remove_keys_inplace,
+ "audio_embeddings_connector": remove_keys_inplace,
+ "adaln_single": convert_ltx2_transformer_adaln_single,
+ }
+
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ update_state_dict_inplace(converted_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
+ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
+ # Video VAE prefix
+ "vae.": "",
+ # Encoder
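+        # The original encoder stores blocks and downsamplers as one flat indexed list; remap it onto the
+        # diffusers down_blocks.N / downsamplers / mid_block layout.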
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.1",
+ "down_blocks.3": "down_blocks.1.downsamplers.0",
+ "down_blocks.4": "down_blocks.2",
+ "down_blocks.5": "down_blocks.2.downsamplers.0",
+ "down_blocks.6": "down_blocks.3",
+ "down_blocks.7": "down_blocks.3.downsamplers.0",
+ "down_blocks.8": "mid_block",
+ # Decoder
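+        # Likewise, the flat decoder list maps onto the diffusers mid_block and up_blocks.N (with upsamplers).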
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ # Common
+ # For all 3D ResNets
+ "res_blocks": "resnets",
+ "per_channel_statistics.mean-of-means": "latents_mean",
+ "per_channel_statistics.std-of-means": "latents_std",
+ }
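+    # Note (illustrative): the original checkpoint numbers the encoder/decoder blocks sequentially, so e.g.
+    # "down_blocks.1" above becomes the downsampler of the first diffusers down block
+    # ("down_blocks.0.downsamplers.0") and "down_blocks.8" becomes the "mid_block".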
+
+ def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ def remove_keys_inplace(key: str, state_dict) -> None:
+ state_dict.pop(key)
+
+ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
+ "per_channel_statistics.channel": remove_keys_inplace,
+ "per_channel_statistics.mean-of-stds": remove_keys_inplace,
+ }
+
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in LTX_2_0_VIDEO_VAE_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ update_state_dict_inplace(converted_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in LTX_2_0_VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
+ LTX_2_0_AUDIO_VAE_RENAME_DICT = {
+ # Audio VAE prefix
+ "audio_vae.": "",
+ "per_channel_statistics.mean-of-means": "latents_mean",
+ "per_channel_statistics.std-of-means": "latents_std",
+ }
+
+ def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in LTX_2_0_AUDIO_VAE_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ update_state_dict_inplace(converted_state_dict, key, new_key)
+
+ return converted_state_dict
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index c4664f00cad2..4d1db36a7352 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -41,6 +41,8 @@
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+ _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
+ _import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -96,6 +98,7 @@
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
+ _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -104,6 +107,7 @@
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+ _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -153,6 +157,8 @@
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
+ AutoencoderKLLTX2Audio,
+ AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -203,6 +209,7 @@
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
+ GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
@@ -212,6 +219,7 @@
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
+ LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index 2a4eb520c796..1c7703a13c52 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -90,10 +90,6 @@ def __post_init__(self):
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
- if self.ring_degree > 1 and self.ulysses_degree > 1:
- raise ValueError(
- "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
- )
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 310c44457c27..61c478b03c4f 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -235,6 +235,10 @@ def decorator(func):
def get_active_backend(cls):
return cls._active_backend, cls._backends[cls._active_backend]
+ @classmethod
+ def set_active_backend(cls, backend: str):
+ cls._active_backend = backend
+
@classmethod
def list_backends(cls):
return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
_maybe_download_kernel_for_backend(backend)
old_backend = _AttentionBackendRegistry._active_backend
- _AttentionBackendRegistry._active_backend = backend
+ _AttentionBackendRegistry.set_active_backend(backend)
try:
yield
finally:
- _AttentionBackendRegistry._active_backend = old_backend
+ _AttentionBackendRegistry.set_active_backend(old_backend)
def dispatch_attention_fn(
@@ -348,6 +352,7 @@ def dispatch_attention_fn(
check(**kwargs)
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+
return backend_fn(**kwargs)
@@ -1106,6 +1111,51 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
+def _npu_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if return_lse:
+ raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+
+ out = npu_fusion_attention(
+ query,
+ key,
+ value,
+ query.size(2), # num_heads
+ input_layout="BSND",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0 - dropout_p,
+ sync=False,
+ inner_precise=0,
+ )[0]
+
+ return out
+
+
+# Not implemented yet.
+def _npu_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+    raise NotImplementedError("Backward pass is not implemented for NPU fusion attention.")
+
+
# ===== Context parallel =====
@@ -1132,6 +1182,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
return x
+def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+ """
+ Perform dimension sharding / reassembly across processes using _all_to_all_single.
+
+    This utility reshapes and redistributes tensor `x` across the given process group, exchanging the sequence and
+    head dimensions as selected by `scatter_idx` and `gather_idx`.
+
+ Args:
+ x (torch.Tensor):
+ Input tensor. Expected shapes:
+ - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
+ - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
+        scatter_idx (int):
+ Dimension along which the tensor is partitioned before all-to-all.
+ gather_idx (int):
+ Dimension along which the output is reassembled after all-to-all.
+        group:
+ Distributed process group for the Ulysses group.
+
+ Returns:
+ torch.Tensor: Tensor with globally exchanged dimensions.
+ - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
+ - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
+ """
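+    # Shape walkthrough (illustrative): with a Ulysses group of 4 ranks and a local input of
+    # (batch_size=2, seq_len_local=8, num_heads=16, head_dim=64), scatter_idx=2 / gather_idx=1
+    # produces (2, 32, 4, 64): the sequence dimension is gathered (8 * 4 = 32) while the heads
+    # are scattered across ranks (16 / 4 = 4).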
+ group_world_size = torch.distributed.get_world_size(group)
+
+ if scatter_idx == 2 and gather_idx == 1:
+        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
+        # dimension and scatters the head dimension.
+ batch_size, seq_len_local, num_heads, head_dim = x.shape
+ seq_len = seq_len_local * group_world_size
+ num_heads_local = num_heads // group_world_size
+
+ # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
+ x_temp = (
+ x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
+ .transpose(0, 2)
+ .contiguous()
+ )
+
+ if group_world_size > 1:
+ out = _all_to_all_single(x_temp, group=group)
+ else:
+ out = x_temp
+ # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
+ out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
+ out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
+ return out
+ elif scatter_idx == 1 and gather_idx == 2:
+        # Used after Ulysses sequence parallel attention in unified SP. Gathers the head
+        # dimension and scatters back the sequence dimension.
+ batch_size, seq_len, num_heads_local, head_dim = x.shape
+ num_heads = num_heads_local * group_world_size
+ seq_len_local = seq_len // group_world_size
+
+ # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+ x_temp = (
+ x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+ .permute(1, 3, 2, 0, 4)
+ .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+ )
+
+ if group_world_size > 1:
+ output = _all_to_all_single(x_temp, group)
+ else:
+ output = x_temp
+ output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
+ output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
+ return output
+ else:
+ raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
+
+
+class SeqAllToAllDim(torch.autograd.Function):
+ """
+    All-to-all operation for unified sequence parallelism. Uses `_all_to_all_dim_exchange`; see that function for
+    details on the dimension exchange.
+ """
+
+ @staticmethod
+ def forward(ctx, group, input, scatter_id=2, gather_id=1):
+ ctx.group = group
+ ctx.scatter_id = scatter_id
+ ctx.gather_id = gather_id
+ return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
+
+ @staticmethod
+ def backward(ctx, grad_outputs):
+ grad_input = SeqAllToAllDim.apply(
+ ctx.group,
+ grad_outputs,
+ ctx.gather_id, # reversed
+ ctx.scatter_id, # reversed
+ )
+ return (None, grad_input, None, None)
+
+
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
@@ -1192,7 +1339,10 @@ def forward(
out = out.to(torch.float32)
lse = lse.to(torch.float32)
- lse = lse.unsqueeze(-1)
+ # Refer to:
+ # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+ if is_torch_version("<", "2.9.0"):
+ lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1253,7 +1403,7 @@ def backward(
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
- return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1348,7 +1498,69 @@ def backward(
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
)
- return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
+
+
+def _templated_unified_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ dropout_p: float,
+ is_causal: bool,
+ scale: Optional[float],
+ enable_gqa: bool,
+ return_lse: bool,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+ scatter_idx: int = 2,
+ gather_idx: int = 1,
+):
+ """
+ Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
+ """
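+    # High-level flow: (1) an all-to-all over the Ulysses group exchanges head shards for sequence shards,
+    # (2) ring attention runs over the ring group on the gathered sequence, and (3) a reverse all-to-all
+    # scatters the sequence back and gathers the full head dimension for the local outputs.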
+ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+ ulysses_group = ulysses_mesh.get_group()
+
+ query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+ key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+ value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+ out = TemplatedRingAttention.apply(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ if return_lse:
+ context_layer, lse, *_ = out
+ else:
+ context_layer = out
+ # context_layer is of shape (B, S, H_LOCAL, D)
+ output = SeqAllToAllDim.apply(
+ ulysses_group,
+ context_layer,
+ gather_idx,
+ scatter_idx,
+ )
+ if return_lse:
+ # lse is of shape (B, S, H_LOCAL, 1)
+ # Refer to:
+ # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+ if is_torch_version("<", "2.9.0"):
+ lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
+ lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
+ lse = lse.squeeze(-1)
+ return (output, lse)
+ return output
def _templated_context_parallel_attention(
@@ -1366,15 +1578,31 @@ def _templated_context_parallel_attention(
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
):
- if attn_mask is not None:
- raise ValueError("Attention mask is not yet supported for templated attention.")
if is_causal:
raise ValueError("Causal attention is not yet supported for templated attention.")
if enable_gqa:
raise ValueError("GQA is not yet supported for templated attention.")
-    # TODO: add support for unified attention with ring/ulysses degree both being > 1
- if _parallel_config.context_parallel_config.ring_degree > 1:
+ if (
+ _parallel_config.context_parallel_config.ring_degree > 1
+ and _parallel_config.context_parallel_config.ulysses_degree > 1
+ ):
+ return _templated_unified_attention(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ elif _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
@@ -1420,6 +1648,7 @@ def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1427,6 +1656,9 @@ def _flash_attention(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
if _parallel_config is None:
out = flash_attn_func(
q=query,
@@ -1469,6 +1701,7 @@ def _flash_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1476,6 +1709,9 @@ def _flash_attention_hub(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
@@ -1612,11 +1848,15 @@ def _flash_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for flash-attn 3.")
+
out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
@@ -1636,6 +1876,7 @@ def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
@@ -1646,6 +1887,8 @@ def _flash_attention_3_hub(
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
@@ -1785,12 +2028,16 @@ def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for aiter attention")
+
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
@@ -1881,6 +2128,43 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
return out
+def _prepare_additive_attn_mask(
+ attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
+) -> torch.Tensor:
+ """
+ Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
+
+ This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
+
+ Args:
+ attn_mask: 2D tensor [batch_size, seq_len_k]
+ - Boolean: True means attend, False means mask out
+ - Additive: 0.0 means attend, -inf means mask out
+ target_dtype: The dtype to convert the mask to (usually query.dtype)
+ reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
+
+ Returns:
+ Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
+ reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
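+
+    Example (illustrative):
+        A boolean mask [[True, False]] with target_dtype=torch.float32 becomes the additive mask
+        [[[[0.0, -inf]]]] of shape (1, 1, 1, 2) when reshape_4d=True.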
+ """
+ # Check if the mask is boolean or already additive
+ if attn_mask.dtype == torch.bool:
+ # Convert boolean to additive: True -> 0.0, False -> -inf
+ attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
+ # Convert to target dtype
+ attn_mask = attn_mask.to(dtype=target_dtype)
+ else:
+ # Already additive mask - just ensure correct dtype
+ attn_mask = attn_mask.to(dtype=target_dtype)
+
+ # Optionally reshape to 4D for broadcasting in attention mechanisms
+ if reshape_4d:
+ batch_size, seq_len_k = attn_mask.shape
+ attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
+
+ return attn_mask
+
+
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
@@ -1900,6 +2184,19 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
+
+ # Reshape 2D mask to 4D for SDPA
+ # SDPA accepts both boolean masks (torch.bool) and additive masks (float)
+ if (
+ attn_mask is not None
+ and attn_mask.ndim == 2
+ and attn_mask.shape[0] == query.shape[0]
+ and attn_mask.shape[1] == key.shape[1]
+ ):
+ # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
+ # SDPA handles both boolean and additive masks correctly
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+
if _parallel_config is None:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
@@ -2028,6 +2325,7 @@ def _native_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -2035,6 +2333,9 @@ def _native_flash_attention(
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for native flash attention.")
+
lse = None
if _parallel_config is None and not return_lse:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2108,34 +2409,52 @@ def _native_math_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_NPU,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
)
def _native_npu_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
- query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
- out = npu_fusion_attention(
- query,
- key,
- value,
- query.size(1), # num_heads
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0 - dropout_p,
- sync=False,
- inner_precise=0,
- )[0]
- out = out.transpose(1, 2).contiguous()
+ if _parallel_config is None:
+ out = npu_fusion_attention(
+ query,
+ key,
+ value,
+ query.size(2), # num_heads
+ input_layout="BSND",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0 - dropout_p,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ None,
+ dropout_p,
+ None,
+ scale,
+ None,
+ return_lse,
+ forward_op=_npu_attention_forward_op,
+ backward_op=_npu_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
return out
@@ -2148,10 +2467,13 @@ def _native_xla_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for XLA attention")
if return_lse:
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2175,11 +2497,14 @@ def _sage_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
if _parallel_config is None:
out = sageattn(
@@ -2223,11 +2548,14 @@ def _sage_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
if _parallel_config is None:
@@ -2309,11 +2637,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
k=key,
@@ -2333,11 +2664,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
k=key,
@@ -2357,11 +2691,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
k=key,
@@ -2381,11 +2718,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
@@ -2423,10 +2763,34 @@ def _xformers_attention(
attn_mask = xops.LowerTriangularMask()
elif attn_mask is not None:
if attn_mask.ndim == 2:
- attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+ # Convert 2D mask to 4D for xformers
+ # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
+ # xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
+ # Need memory alignment - create larger tensor and slice for alignment
+ original_seq_len = attn_mask.size(1)
+ aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
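+            # e.g. (illustrative): a 77-column mask is padded to 80 aligned columns, since ((77 + 7) // 8) * 8 = 80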
+
+ # Create aligned 4D tensor and slice to ensure proper memory layout
+ aligned_mask = torch.zeros(
+ (batch_size, num_heads_q, seq_len_q, aligned_seq_len),
+ dtype=query.dtype,
+ device=query.device,
+ )
+ # Convert to 4D additive mask (handles both boolean and additive inputs)
+ mask_additive = _prepare_additive_attn_mask(
+ attn_mask, target_dtype=query.dtype
+ ) # [batch, 1, 1, seq_len_k]
+ # Broadcast to [batch, heads, seq_q, seq_len_k]
+ aligned_mask[:, :, :, :original_seq_len] = mask_additive
+            # Mask out the padded columns beyond the original sequence length
+ aligned_mask[:, :, :, original_seq_len:] = float("-inf")
+
+ # Slice to actual size with proper alignment
+ attn_mask = aligned_mask[:, :, :, :seq_len_kv]
elif attn_mask.ndim != 4:
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
- attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+ elif attn_mask.ndim == 4:
+ attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
if enable_gqa:
if num_heads_q % num_heads_kv != 0:
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index 56df27f93cd7..8e7a9c81d2ad 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -10,6 +10,8 @@
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
+from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py
index ec301ef8ad51..19b666fdc4a8 100644
--- a/src/diffusers/models/autoencoders/autoencoder_dc.py
+++ b/src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -102,14 +102,14 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
- qkv_mutliscales: Tuple[int, ...] = (),
+ qkv_multiscales: Tuple[int, ...] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
- in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
+ in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
)
else:
@@ -247,7 +247,7 @@ def __init__(
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
- qkv_mutliscales=qkv_multiscales[i],
+ qkv_multiscales=qkv_multiscales[i],
)
down_block_list.append(block)
@@ -339,7 +339,7 @@ def __init__(
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
- qkv_mutliscales=qkv_multiscales[i],
+ qkv_multiscales=qkv_multiscales[i],
)
up_block_list.append(block)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
index 616d0d415840..c02f11bef40a 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
@@ -27,7 +27,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -410,7 +410,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return h
-class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
@@ -486,27 +486,6 @@ def enable_tiling(
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
index 2249063a9f00..973574e616bf 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
@@ -26,7 +26,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -584,7 +584,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states
-class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
+class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
@@ -685,27 +685,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
index 4b1beb74a3bc..c662c1657513 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
@@ -26,7 +26,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -625,7 +625,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states
-class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
+class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanVideo-1.5.
@@ -723,27 +723,6 @@ def enable_tiling(
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
new file mode 100644
index 000000000000..01dd55a938b6
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
@@ -0,0 +1,1521 @@
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
+
+
+class PerChannelRMSNorm(nn.Module):
+ """
+ Per-pixel (per-location) RMS normalization layer.
+
+    At each location, this layer normalizes the tensor by the root-mean-square of its values across the chosen
+    dimension:
+
+ y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
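+
+    For example (illustrative), a channel vector [3.0, 4.0] has RMS sqrt((9 + 16) / 2) ≈ 3.54, so it is normalized
+    to approximately [0.85, 1.13].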
+ """
+
+ def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:
+ """
+ Args:
+            channel_dim: Dimension along which to compute the RMS (typically the channel dimension).
+ eps: Small constant added for numerical stability.
+ """
+ super().__init__()
+ self.channel_dim = channel_dim
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor:
+ """
+        Apply RMS normalization along `channel_dim` if given, otherwise along the configured dimension.
+        """
+        channel_dim = self.channel_dim if channel_dim is None else channel_dim
+        # Compute the mean of squared values along `channel_dim`, keeping dimensions for broadcasting.
+        mean_sq = torch.mean(x**2, dim=channel_dim, keepdim=True)
+ # Normalize by the root-mean-square (RMS).
+ rms = torch.sqrt(mean_sq + self.eps)
+ return x / rms
+
+
+# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime
+class LTX2VideoCausalConv3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ groups: int = 1,
+ spatial_padding_mode: str = "zeros",
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
+
+ dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
+ stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+ height_pad = self.kernel_size[1] // 2
+ width_pad = self.kernel_size[2] // 2
+ padding = (0, height_pad, width_pad)
+
+ self.conv = nn.Conv3d(
+ in_channels,
+ out_channels,
+ self.kernel_size,
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ padding=padding,
+ padding_mode=spatial_padding_mode,
+ )
+
+ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
+ time_kernel_size = self.kernel_size[0]
+
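+        # Temporal padding: causal mode replicates the first frame on the left so no output position sees
+        # future frames; non-causal mode pads symmetrically with the edge frames on both sides.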
+ if causal:
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1))
+ hidden_states = torch.concatenate([pad_left, hidden_states], dim=2)
+ else:
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
+ pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
+ hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2)
+
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding
+# mode is configurable
+class LTX2VideoResnetBlock3d(nn.Module):
+ r"""
+ A 3D ResNet block used in the LTX 2.0 audiovisual model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ elementwise_affine (`bool`, defaults to `False`):
+ Whether to enable elementwise affinity in the normalization layers.
+ non_linearity (`str`, defaults to `"swish"`):
+ Activation function to use.
+ conv_shortcut (bool, defaults to `False`):
+ Whether or not to use a convolution shortcut.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ dropout: float = 0.0,
+ eps: float = 1e-6,
+ elementwise_affine: bool = False,
+ non_linearity: str = "swish",
+ inject_noise: bool = False,
+ timestep_conditioning: bool = False,
+ spatial_padding_mode: str = "zeros",
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.norm1 = PerChannelRMSNorm()
+ self.conv1 = LTX2VideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ self.norm2 = PerChannelRMSNorm()
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = LTX2VideoCausalConv3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ self.norm3 = None
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
+ # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d
+ self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
+
+ self.per_channel_scale1 = None
+ self.per_channel_scale2 = None
+ if inject_noise:
+ self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
+ self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
+
+ self.scale_shift_table = None
+ if timestep_conditioning:
+ self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ causal: bool = True,
+ ) -> torch.Tensor:
+ hidden_states = inputs
+
+ hidden_states = self.norm1(hidden_states)
+
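+        # When timestep conditioning is enabled, temb carries four per-channel chunks
+        # (shift_1, scale_1, shift_2, scale_2) that, offset by the learned scale_shift_table,
+        # modulate the normalized activations before conv1 and conv2 respectively.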
+ if self.scale_shift_table is not None:
+ temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
+ shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
+ hidden_states = hidden_states * (1 + scale_1) + shift_1
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states, causal=causal)
+
+ if self.per_channel_scale1 is not None:
+ spatial_shape = hidden_states.shape[-2:]
+ spatial_noise = torch.randn(
+ spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
+ )[None]
+ hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
+
+ hidden_states = self.norm2(hidden_states)
+
+ if self.scale_shift_table is not None:
+ hidden_states = hidden_states * (1 + scale_2) + shift_2
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states, causal=causal)
+
+ if self.per_channel_scale2 is not None:
+ spatial_shape = hidden_states.shape[-2:]
+ spatial_noise = torch.randn(
+ spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
+ )[None]
+ hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
+
+ if self.norm3 is not None:
+ inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
+
+ if self.conv_shortcut is not None:
+ inputs = self.conv_shortcut(inputs)
+
+ hidden_states = hidden_states + inputs
+ return hidden_states
+
+
+# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
+class LTXVideoDownsampler3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ spatial_padding_mode: str = "zeros",
+ ) -> None:
+ super().__init__()
+
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+        self.group_size = (in_channels * self.stride[0] * self.stride[1] * self.stride[2]) // out_channels
+
+ out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
+
+ self.conv = LTX2VideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
+ hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
+
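+        # Residual path: a space-to-channel rearrangement folds the (t, h, w) stride factors into the channel
+        # dimension, then groups of `group_size` channels are averaged down to the target channel count.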
+ residual = (
+ hidden_states.unflatten(4, (-1, self.stride[2]))
+ .unflatten(3, (-1, self.stride[1]))
+ .unflatten(2, (-1, self.stride[0]))
+ )
+ residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+ residual = residual.unflatten(1, (-1, self.group_size))
+ residual = residual.mean(dim=2)
+
+ hidden_states = self.conv(hidden_states, causal=causal)
+ hidden_states = (
+ hidden_states.unflatten(4, (-1, self.stride[2]))
+ .unflatten(3, (-1, self.stride[1]))
+ .unflatten(2, (-1, self.stride[0]))
+ )
+ hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
+class LTXVideoUpsampler3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ residual: bool = False,
+ upscale_factor: int = 1,
+ spatial_padding_mode: str = "zeros",
+ ) -> None:
+ super().__init__()
+
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+ self.residual = residual
+ self.upscale_factor = upscale_factor
+
+        out_channels = (in_channels * self.stride[0] * self.stride[1] * self.stride[2]) // upscale_factor
+
+ self.conv = LTX2VideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
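+        # Depth-to-space: the conv output (and optionally a repeated residual) is rearranged so that groups of
+        # channels become the (t, h, w) upsampling factors, and the extra leading frames introduced by the
+        # temporal stride are trimmed off.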
+ if self.residual:
+ residual = hidden_states.reshape(
+ batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
+ )
+ residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
+ residual = residual.repeat(1, repeats, 1, 1, 1)
+ residual = residual[:, :, self.stride[0] - 1 :]
+
+ hidden_states = self.conv(hidden_states, causal=causal)
+ hidden_states = hidden_states.reshape(
+ batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
+ )
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
+
+ if self.residual:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d
+class LTX2VideoDownBlock3D(nn.Module):
+ r"""
+ Down block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+        spatio_temporal_scale (`bool`, defaults to `True`):
+            Whether or not to use a downsampling layer. If not used, the output dimensions are the same as the
+            input dimensions.
+        downsample_type (`str`, defaults to `"conv"`):
+            The type of downsampling layer: `"conv"`, `"spatial"`, `"temporal"`, or `"spatiotemporal"`.
+        spatial_padding_mode (`str`, defaults to `"zeros"`):
+            Padding mode used by the spatial convolutions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ spatio_temporal_scale: bool = True,
+ downsample_type: str = "conv",
+ spatial_padding_mode: str = "zeros",
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ LTX2VideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.downsamplers = None
+ if spatio_temporal_scale:
+ self.downsamplers = nn.ModuleList()
+
+ if downsample_type == "conv":
+ self.downsamplers.append(
+ LTX2VideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=(2, 2, 2),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ )
+ elif downsample_type == "spatial":
+ self.downsamplers.append(
+ LTXVideoDownsampler3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=(1, 2, 2),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ )
+ elif downsample_type == "temporal":
+ self.downsamplers.append(
+ LTXVideoDownsampler3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=(2, 1, 1),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ )
+ elif downsample_type == "spatiotemporal":
+ self.downsamplers.append(
+ LTXVideoDownsampler3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=(2, 2, 2),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ causal: bool = True,
+ ) -> torch.Tensor:
+        r"""Forward method of the `LTX2VideoDownBlock3D` class."""
+
+ for i, resnet in enumerate(self.resnets):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
+ else:
+ hidden_states = resnet(hidden_states, temb, generator, causal=causal)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, causal=causal)
+
+ return hidden_states
+
+
+# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
+# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d
+class LTX2VideoMidBlock3d(nn.Module):
+ r"""
+ A middle block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+        inject_noise (`bool`, defaults to `False`):
+            Whether to inject learned per-channel spatial noise after the convolutions.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to condition the block on a timestep embedding.
+        spatial_padding_mode (`str`, defaults to `"zeros"`):
+            Padding mode used by the spatial convolutions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ inject_noise: bool = False,
+ timestep_conditioning: bool = False,
+ spatial_padding_mode: str = "zeros",
+ ) -> None:
+ super().__init__()
+
+ self.time_embedder = None
+ if timestep_conditioning:
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ LTX2VideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ inject_noise=inject_noise,
+ timestep_conditioning=timestep_conditioning,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ causal: bool = True,
+ ) -> torch.Tensor:
+        r"""Forward method of the `LTX2VideoMidBlock3d` class."""
+
+ if self.time_embedder is not None:
+ temb = self.time_embedder(
+ timestep=temb.flatten(),
+ resolution=None,
+ aspect_ratio=None,
+ batch_size=hidden_states.size(0),
+ hidden_dtype=hidden_states.dtype,
+ )
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
+
+ for i, resnet in enumerate(self.resnets):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
+ else:
+ hidden_states = resnet(hidden_states, temb, generator, causal=causal)
+
+ return hidden_states
+
+
+# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d
+class LTX2VideoUpBlock3d(nn.Module):
+ r"""
+ Up block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+        spatio_temporal_scale (`bool`, defaults to `True`):
+            Whether or not to use an upsampling layer. If not used, the output dimensions are the same as the
+            input dimensions.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to condition the block on a timestep embedding.
+        spatial_padding_mode (`str`, defaults to `"zeros"`):
+            Padding mode used by the spatial convolutions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ spatio_temporal_scale: bool = True,
+ inject_noise: bool = False,
+ timestep_conditioning: bool = False,
+ upsample_residual: bool = False,
+ upscale_factor: int = 1,
+ spatial_padding_mode: str = "zeros",
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ self.time_embedder = None
+ if timestep_conditioning:
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
+
+ self.conv_in = None
+ if in_channels != out_channels:
+ self.conv_in = LTX2VideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ inject_noise=inject_noise,
+ timestep_conditioning=timestep_conditioning,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ self.upsamplers = None
+ if spatio_temporal_scale:
+ self.upsamplers = nn.ModuleList(
+ [
+ LTXVideoUpsampler3d(
+ out_channels * upscale_factor,
+ stride=(2, 2, 2),
+ residual=upsample_residual,
+ upscale_factor=upscale_factor,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ ]
+ )
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ LTX2VideoResnetBlock3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ inject_noise=inject_noise,
+ timestep_conditioning=timestep_conditioning,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ causal: bool = True,
+ ) -> torch.Tensor:
+ if self.conv_in is not None:
+ hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal)
+
+ if self.time_embedder is not None:
+ temb = self.time_embedder(
+ timestep=temb.flatten(),
+ resolution=None,
+ aspect_ratio=None,
+ batch_size=hidden_states.size(0),
+ hidden_dtype=hidden_states.dtype,
+ )
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, causal=causal)
+
+ for i, resnet in enumerate(self.resnets):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
+ else:
+ hidden_states = resnet(hidden_states, temb, generator, causal=causal)
+
+ return hidden_states
+
+
+# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is
+# different, as is the layers_per_block (the 2.0 VAE is bigger)
+class LTX2VideoEncoder3d(nn.Module):
+ r"""
+    The `LTX2VideoEncoder3d` layer of a variational autoencoder that encodes input video samples into their latent
+    representation.
+
+ Args:
+ in_channels (`int`, defaults to 3):
+ Number of input channels.
+ out_channels (`int`, defaults to 128):
+ Number of latent channels.
+ block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`):
+ The number of output channels for each block.
+        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, True)`):
+ Whether a block should contain spatio-temporal downscaling layers or not.
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`):
+ The number of layers per block.
+ downsample_type (`Tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`):
+ The spatiotemporal downsampling pattern per block. Per-layer values can be
+ - `"spatial"` (downsample spatial dims by 2x)
+ - `"temporal"` (downsample temporal dim by 2x)
+ - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x)
+ patch_size (`int`, defaults to `4`):
+ The size of spatial patches.
+ patch_size_t (`int`, defaults to `1`):
+ The size of temporal patches.
+ resnet_norm_eps (`float`, defaults to `1e-6`):
+ Epsilon value for ResNet normalization layers.
+ is_causal (`bool`, defaults to `True`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 128,
+ block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048),
+ down_block_types: Tuple[str, ...] = (
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ ),
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True),
+ layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2),
+ downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+ patch_size: int = 4,
+ patch_size_t: int = 1,
+ resnet_norm_eps: float = 1e-6,
+ is_causal: bool = True,
+ spatial_padding_mode: str = "zeros",
+ ):
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.in_channels = in_channels * patch_size**2
+ self.is_causal = is_causal
+
+ output_channel = out_channels
+
+ self.conv_in = LTX2VideoCausalConv3d(
+ in_channels=self.in_channels,
+ out_channels=output_channel,
+ kernel_size=3,
+ stride=1,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ # down blocks
+ num_block_out_channels = len(block_out_channels)
+ self.down_blocks = nn.ModuleList([])
+ for i in range(num_block_out_channels):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+
+ if down_block_types[i] == "LTX2VideoDownBlock3D":
+ down_block = LTX2VideoDownBlock3D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ num_layers=layers_per_block[i],
+ resnet_eps=resnet_norm_eps,
+ spatio_temporal_scale=spatio_temporal_scaling[i],
+ downsample_type=downsample_type[i],
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ else:
+ raise ValueError(f"Unknown down block type: {down_block_types[i]}")
+
+ self.down_blocks.append(down_block)
+
+ # mid block
+ self.mid_block = LTX2VideoMidBlock3d(
+ in_channels=output_channel,
+ num_layers=layers_per_block[-1],
+ resnet_eps=resnet_norm_eps,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ # out
+ self.norm_out = PerChannelRMSNorm()
+ self.conv_act = nn.SiLU()
+ self.conv_out = LTX2VideoCausalConv3d(
+ in_channels=output_channel,
+ out_channels=out_channels + 1,
+ kernel_size=3,
+ stride=1,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:
+ r"""The forward method of the `LTXVideoEncoder3d` class."""
+
+ p = self.patch_size
+ p_t = self.patch_size_t
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+        causal = self.is_causal if causal is None else causal
+
+ hidden_states = hidden_states.reshape(
+ batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
+ )
+        # The patch sub-dimensions are folded into the channel axis in this specific, non-obvious order to match the
+        # original LTX-2 patchification.
+ hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
+ hidden_states = self.conv_in(hidden_states, causal=causal)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for down_block in self.down_blocks:
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal)
+
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal)
+ else:
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states, causal=causal)
+
+ hidden_states = self.mid_block(hidden_states, causal=causal)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states, causal=causal)
+
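+        # The encoder predicts `out_channels` means plus a single shared log-variance channel; replicate that last
+        # channel so the output has `2 * out_channels` channels, as expected by `DiagonalGaussianDistribution`.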
+ last_channel = hidden_states[:, -1:]
+ last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
+ hidden_states = torch.cat([hidden_states, last_channel], dim=1)
+
+ return hidden_states
+
+
+# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2
+class LTX2VideoDecoder3d(nn.Module):
+ r"""
+    The `LTX2VideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
+ sample.
+
+ Args:
+ in_channels (`int`, defaults to 128):
+ Number of latent channels.
+ out_channels (`int`, defaults to 3):
+ Number of output channels.
+        block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024)`):
+            The number of output channels for each block.
+        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True)`):
+            Whether a block should contain spatio-temporal upscaling layers or not.
+        layers_per_block (`Tuple[int, ...]`, defaults to `(5, 5, 5, 5)`):
+            The number of layers per block.
+ patch_size (`int`, defaults to `4`):
+ The size of spatial patches.
+ patch_size_t (`int`, defaults to `1`):
+ The size of temporal patches.
+ resnet_norm_eps (`float`, defaults to `1e-6`):
+ Epsilon value for ResNet normalization layers.
+ is_causal (`bool`, defaults to `False`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ timestep_conditioning (`bool`, defaults to `False`):
+ Whether to condition the model on timesteps.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 128,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (256, 512, 1024),
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True),
+ layers_per_block: Tuple[int, ...] = (5, 5, 5, 5),
+ patch_size: int = 4,
+ patch_size_t: int = 1,
+ resnet_norm_eps: float = 1e-6,
+ is_causal: bool = False,
+ inject_noise: Tuple[bool, ...] = (False, False, False),
+ timestep_conditioning: bool = False,
+ upsample_residual: Tuple[bool, ...] = (True, True, True),
+        upsample_factor: Tuple[int, ...] = (2, 2, 2),
+ spatial_padding_mode: str = "reflect",
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.out_channels = out_channels * patch_size**2
+ self.is_causal = is_causal
+
+ block_out_channels = tuple(reversed(block_out_channels))
+ spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
+ layers_per_block = tuple(reversed(layers_per_block))
+ inject_noise = tuple(reversed(inject_noise))
+ upsample_residual = tuple(reversed(upsample_residual))
+ upsample_factor = tuple(reversed(upsample_factor))
+ output_channel = block_out_channels[0]
+
+ self.conv_in = LTX2VideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=output_channel,
+ kernel_size=3,
+ stride=1,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ self.mid_block = LTX2VideoMidBlock3d(
+ in_channels=output_channel,
+ num_layers=layers_per_block[0],
+ resnet_eps=resnet_norm_eps,
+ inject_noise=inject_noise[0],
+ timestep_conditioning=timestep_conditioning,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ # up blocks
+ num_block_out_channels = len(block_out_channels)
+ self.up_blocks = nn.ModuleList([])
+ for i in range(num_block_out_channels):
+ input_channel = output_channel // upsample_factor[i]
+ output_channel = block_out_channels[i] // upsample_factor[i]
+
+ up_block = LTX2VideoUpBlock3d(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ num_layers=layers_per_block[i + 1],
+ resnet_eps=resnet_norm_eps,
+ spatio_temporal_scale=spatio_temporal_scaling[i],
+ inject_noise=inject_noise[i + 1],
+ timestep_conditioning=timestep_conditioning,
+ upsample_residual=upsample_residual[i],
+ upscale_factor=upsample_factor[i],
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.norm_out = PerChannelRMSNorm()
+ self.conv_act = nn.SiLU()
+ self.conv_out = LTX2VideoCausalConv3d(
+ in_channels=output_channel,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=1,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ # timestep embedding
+ self.time_embedder = None
+ self.scale_shift_table = None
+ self.timestep_scale_multiplier = None
+ if timestep_conditioning:
+ self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ causal: Optional[bool] = None,
+ ) -> torch.Tensor:
+        causal = self.is_causal if causal is None else causal
+
+ hidden_states = self.conv_in(hidden_states, causal=causal)
+
+ if self.timestep_scale_multiplier is not None:
+ temb = temb * self.timestep_scale_multiplier
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal)
+
+ for up_block in self.up_blocks:
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal)
+ else:
+ hidden_states = self.mid_block(hidden_states, temb, causal=causal)
+
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states, temb, causal=causal)
+
+ hidden_states = self.norm_out(hidden_states)
+
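+        # When timestep conditioning is enabled, derive a per-channel (shift, scale) pair from the timestep embedding
+        # and the learned `scale_shift_table`, and apply it after the output norm.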
+ if self.time_embedder is not None:
+ temb = self.time_embedder(
+ timestep=temb.flatten(),
+ resolution=None,
+ aspect_ratio=None,
+ batch_size=hidden_states.size(0),
+ hidden_dtype=hidden_states.dtype,
+ )
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
+ temb = temb + self.scale_shift_table[None, ..., None, None, None]
+ shift, scale = temb.unbind(dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states, causal=causal)
+
+ p = self.patch_size
+ p_t = self.patch_size_t
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width)
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ return hidden_states
+
+
+class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+ [LTX-2](https://huggingface.co/Lightricks/LTX-2).
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods implemented
+    for all models (such as downloading or saving).
+
+ Args:
+ in_channels (`int`, defaults to `3`):
+ Number of input channels.
+ out_channels (`int`, defaults to `3`):
+ Number of output channels.
+ latent_channels (`int`, defaults to `128`):
+ Number of latent channels.
+        block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`):
+            The number of output channels for each encoder block.
+        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, True)`):
+            Whether an encoder block should contain spatio-temporal downscaling or not.
+        layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`):
+            The number of layers per encoder block.
+ patch_size (`int`, defaults to `4`):
+ The size of spatial patches.
+ patch_size_t (`int`, defaults to `1`):
+ The size of temporal patches.
+ resnet_norm_eps (`float`, defaults to `1e-6`):
+ Epsilon value for ResNet normalization layers.
+ scaling_factor (`float`, *optional*, defaults to `1.0`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
+ encoder_causal (`bool`, defaults to `True`):
+ Whether the encoder should behave causally (future frames depend only on past frames) or not.
+        decoder_causal (`bool`, defaults to `True`):
+            Whether the decoder should behave causally (future frames depend only on past frames) or not.
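+
+    Example (a minimal, hypothetical usage sketch; it assumes the class is exported from `diffusers` like the other
+    autoencoders and uses an arbitrary random input shape):
+
+    ```py
+    >>> import torch
+    >>> from diffusers import AutoencoderKLLTX2Video
+
+    >>> vae = AutoencoderKLLTX2Video()
+    >>> video = torch.randn(1, 3, 9, 256, 256)  # (batch, channels, frames, height, width)
+    >>> latents = vae.encode(video).latent_dist.sample()
+    >>> reconstruction = vae.decode(latents).sample
+    ```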
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ latent_channels: int = 128,
+ block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048),
+ down_block_types: Tuple[str, ...] = (
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoDownBlock3D",
+ ),
+ decoder_block_out_channels: Tuple[int, ...] = (256, 512, 1024),
+ layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2),
+ decoder_layers_per_block: Tuple[int, ...] = (5, 5, 5, 5),
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True),
+ decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True),
+ decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False),
+ downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+ upsample_residual: Tuple[bool, ...] = (True, True, True),
+ upsample_factor: Tuple[int, ...] = (2, 2, 2),
+ timestep_conditioning: bool = False,
+ patch_size: int = 4,
+ patch_size_t: int = 1,
+ resnet_norm_eps: float = 1e-6,
+ scaling_factor: float = 1.0,
+ encoder_causal: bool = True,
+ decoder_causal: bool = True,
+ encoder_spatial_padding_mode: str = "zeros",
+ decoder_spatial_padding_mode: str = "reflect",
+        spatial_compression_ratio: Optional[int] = None,
+        temporal_compression_ratio: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = LTX2VideoEncoder3d(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ block_out_channels=block_out_channels,
+ down_block_types=down_block_types,
+ spatio_temporal_scaling=spatio_temporal_scaling,
+ layers_per_block=layers_per_block,
+ downsample_type=downsample_type,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ resnet_norm_eps=resnet_norm_eps,
+ is_causal=encoder_causal,
+ spatial_padding_mode=encoder_spatial_padding_mode,
+ )
+ self.decoder = LTX2VideoDecoder3d(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=decoder_block_out_channels,
+ spatio_temporal_scaling=decoder_spatio_temporal_scaling,
+ layers_per_block=decoder_layers_per_block,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ resnet_norm_eps=resnet_norm_eps,
+ is_causal=decoder_causal,
+ timestep_conditioning=timestep_conditioning,
+ inject_noise=decoder_inject_noise,
+ upsample_residual=upsample_residual,
+ upsample_factor=upsample_factor,
+ spatial_padding_mode=decoder_spatial_padding_mode,
+ )
+
+ latents_mean = torch.zeros((latent_channels,), requires_grad=False)
+ latents_std = torch.ones((latent_channels,), requires_grad=False)
+ self.register_buffer("latents_mean", latents_mean, persistent=True)
+ self.register_buffer("latents_std", latents_std, persistent=True)
+
+ self.spatial_compression_ratio = (
+ patch_size * 2 ** sum(spatio_temporal_scaling)
+ if spatial_compression_ratio is None
+ else spatial_compression_ratio
+ )
+ self.temporal_compression_ratio = (
+ patch_size_t * 2 ** sum(spatio_temporal_scaling)
+ if temporal_compression_ratio is None
+ else temporal_compression_ratio
+ )
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+ # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
+ self.use_framewise_encoding = False
+ self.use_framewise_decoding = False
+
+ # This can be configured based on the amount of GPU memory available.
+ # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
+ # Setting it to higher values results in higher memory usage.
+ self.num_sample_frames_batch_size = 16
+ self.num_latent_frames_batch_size = 2
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 512
+ self.tile_sample_min_width = 512
+ self.tile_sample_min_num_frames = 16
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 448
+ self.tile_sample_stride_width = 448
+ self.tile_sample_stride_num_frames = 8
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_min_num_frames: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_sample_stride_num_frames: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_min_num_frames (`int`, *optional*):
+                The minimum number of frames required for a sample to be separated into tiles across the frame
+                dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+                artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+            tile_sample_stride_num_frames (`int`, *optional*):
+                The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
+                produced across the frame dimension.
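+
+        Example (a minimal, hypothetical sketch; the values below are arbitrary illustrations, not tuned defaults):
+
+        ```py
+        >>> # `vae` is an instance of `AutoencoderKLLTX2Video`
+        >>> vae.enable_tiling(tile_sample_min_height=512, tile_sample_min_width=512)
+        >>> # Subsequent `vae.encode(...)` / `vae.decode(...)` calls process large inputs in overlapping tiles.
+        ```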
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
+
+ def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
+ return self._temporal_tiled_encode(x, causal=causal)
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x, causal=causal)
+
+ enc = self.encoder(x, causal=causal)
+
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x, causal=causal)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(
+ self,
+ z: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ causal: Optional[bool] = None,
+ return_dict: bool = True,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+ if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+ return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict)
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict)
+
+ dec = self.decoder(z, temb, causal=causal)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(
+ self,
+ z: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ causal: Optional[bool] = None,
+ return_dict: bool = True,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ if temb is not None:
+ decoded_slices = [
+ self._decode(z_slice, t_slice, causal=causal).sample
+                    for z_slice, t_slice in zip(z.split(1), temb.split(1))
+ ]
+ else:
+ decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z, temb, causal=causal).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
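+    # The `blend_*` helpers below linearly cross-fade tensor `b` into tensor `a` over `blend_extent` positions along
+    # the height (v), width (h) or temporal (t) axis, hiding seams between overlapping tiles.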
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ time = self.encoder(
+ x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width],
+ causal=causal,
+ )
+
+ row.append(time)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(
+ self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ time = self.decoder(
+ z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal
+ )
+
+ row.append(time)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+    def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+ blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+ row = []
+ for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+ tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+ if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+ tile = self.tiled_encode(tile, causal=causal)
+ else:
+ tile = self.encoder(tile, causal=causal)
+ if i > 0:
+ tile = tile[:, :, 1:, :, :]
+ row.append(tile)
+
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+ result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+ else:
+ result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+ enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+ return enc
+
+ def _temporal_tiled_decode(
+ self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+ blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+ row = []
+ for i in range(0, num_frames, tile_latent_stride_num_frames):
+ tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+ if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+ decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample
+ else:
+ decoded = self.decoder(tile, temb, causal=causal)
+ if i > 0:
+ decoded = decoded[:, :, :-1, :, :]
+ row.append(decoded)
+
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+ tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+ result_row.append(tile)
+ else:
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+ dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ sample_posterior: bool = False,
+ encoder_causal: Optional[bool] = None,
+ decoder_causal: Optional[bool] = None,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+ x = sample
+ posterior = self.encode(x, causal=encoder_causal).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, temb, causal=decoder_causal)
+ if not return_dict:
+ return (dec.sample,)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
new file mode 100644
index 000000000000..6c9c7dce3d2f
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
@@ -0,0 +1,804 @@
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
+
+
+LATENT_DOWNSAMPLE_FACTOR = 4
+
+
+class LTX2AudioCausalConv2d(nn.Module):
+ """
+ A causal 2D convolution that pads asymmetrically along the causal axis.
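+
+    For example, with `causality_axis="height"` and a 3x3 kernel, the input is padded by 2 rows at the top and 0 rows
+    at the bottom (and symmetrically along the width), so each output position only sees current and past rows.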
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: int = 1,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ causality_axis: str = "height",
+ ) -> None:
+ super().__init__()
+
+ self.causality_axis = causality_axis
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+ dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
+
+ pad_h = (kernel_size[0] - 1) * dilation[0]
+ pad_w = (kernel_size[1] - 1) * dilation[1]
+
+ if self.causality_axis == "none":
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+ elif self.causality_axis in {"width", "width-compatibility"}:
+ padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
+ elif self.causality_axis == "height":
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
+ else:
+ raise ValueError(f"Invalid causality_axis: {causality_axis}")
+
+ self.padding = padding
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=0,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pad(x, self.padding)
+ return self.conv(x)
+
+
+class LTX2AudioPixelNorm(nn.Module):
+ """
+ Per-pixel (per-location) RMS normalization layer.
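+
+    Normalizes the input as `x / sqrt(mean(x**2, dim=dim, keepdim=True) + eps)`, i.e. an RMS norm taken over the
+    channel dimension by default.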
+ """
+
+ def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
+ rms = torch.sqrt(mean_sq + self.eps)
+ return x / rms
+
+
+class LTX2AudioAttnBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ norm_type: str = "group",
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+
+ if norm_type == "group":
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ elif norm_type == "pixel":
+ self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
+ else:
+ raise ValueError(f"Invalid normalization type: {norm_type}")
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(x)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ batch, channels, height, width = q.shape
+ q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
+ k = k.reshape(batch, channels, height * width).contiguous()
+ attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
+ attn = torch.nn.functional.softmax(attn, dim=2)
+
+ v = v.reshape(batch, channels, height * width)
+ attn = attn.permute(0, 2, 1).contiguous()
+ h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
+
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class LTX2AudioResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ conv_shortcut: bool = False,
+ dropout: float = 0.0,
+ temb_channels: int = 512,
+ norm_type: str = "group",
+ causality_axis: str = "height",
+ ) -> None:
+ super().__init__()
+ self.causality_axis = causality_axis
+
+ if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
+ raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ if norm_type == "group":
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ elif norm_type == "pixel":
+ self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
+ else:
+ raise ValueError(f"Invalid normalization type: {norm_type}")
+ self.non_linearity = nn.SiLU()
+ if causality_axis is not None:
+ self.conv1 = LTX2AudioCausalConv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
+ )
+ else:
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels > 0:
+ self.temb_proj = nn.Linear(temb_channels, out_channels)
+ if norm_type == "group":
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ elif norm_type == "pixel":
+ self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
+ else:
+ raise ValueError(f"Invalid normalization type: {norm_type}")
+ self.dropout = nn.Dropout(dropout)
+ if causality_axis is not None:
+ self.conv2 = LTX2AudioCausalConv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
+ )
+ else:
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ if causality_axis is not None:
+ self.conv_shortcut = LTX2AudioCausalConv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
+ )
+ else:
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ if causality_axis is not None:
+ self.nin_shortcut = LTX2AudioCausalConv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
+ )
+ else:
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ h = self.norm1(x)
+ h = self.non_linearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = self.non_linearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
+
+ return x + h
+
+
+class LTX2AudioDownsample(nn.Module):
+ def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
+ super().__init__()
+ self.with_conv = with_conv
+ self.causality_axis = causality_axis
+
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.with_conv:
+ # Padding tuple is in the order: (left, right, top, bottom).
+ if self.causality_axis == "none":
+ pad = (0, 1, 0, 1)
+ elif self.causality_axis == "width":
+ pad = (2, 0, 0, 1)
+ elif self.causality_axis == "height":
+ pad = (0, 1, 2, 0)
+ elif self.causality_axis == "width-compatibility":
+ pad = (1, 0, 0, 1)
+ else:
+ raise ValueError(
+ f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
+ f" and `width-compatibility`."
+ )
+
+ x = F.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ # with_conv=False implies that causality_axis is "none"
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class LTX2AudioUpsample(nn.Module):
+ def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
+ super().__init__()
+ self.with_conv = with_conv
+ self.causality_axis = causality_axis
+ if self.with_conv:
+ if causality_axis is not None:
+ self.conv = LTX2AudioCausalConv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
+ )
+ else:
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ if self.causality_axis is None or self.causality_axis == "none":
+ pass
+ elif self.causality_axis == "height":
+ x = x[:, :, 1:, :]
+ elif self.causality_axis == "width":
+ x = x[:, :, :, 1:]
+ elif self.causality_axis == "width-compatibility":
+ pass
+ else:
+ raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
+
+ return x
+
+
+class LTX2AudioAudioPatchifier:
+ """
+ Patchifier for spectrogram/audio latents.
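+
+    `patchify` flattens latents of shape `(batch, channels, time, freq)` into token sequences of shape
+    `(batch, time, channels * freq)`; `unpatchify` inverts this given the original `channels` and `mel_bins`. For
+    example, a `(1, 8, 16, 16)` latent becomes a `(1, 16, 128)` sequence.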
+ """
+
+ def __init__(
+ self,
+ patch_size: int,
+ sample_rate: int = 16000,
+ hop_length: int = 160,
+ audio_latent_downsample_factor: int = 4,
+ is_causal: bool = True,
+ ):
+ self.hop_length = hop_length
+ self.sample_rate = sample_rate
+ self.audio_latent_downsample_factor = audio_latent_downsample_factor
+ self.is_causal = is_causal
+ self._patch_size = (1, patch_size, patch_size)
+
+ def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
+ batch, channels, time, freq = audio_latents.shape
+ return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
+
+ def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
+ batch, time, _ = audio_latents.shape
+ return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
+
+ @property
+ def patch_size(self) -> Tuple[int, int, int]:
+ return self._patch_size
+
+
+class LTX2AudioEncoder(nn.Module):
+ def __init__(
+ self,
+ base_channels: int = 128,
+ output_channels: int = 1,
+ num_res_blocks: int = 2,
+ attn_resolutions: Optional[Tuple[int, ...]] = None,
+ in_channels: int = 2,
+ resolution: int = 256,
+ latent_channels: int = 8,
+ ch_mult: Tuple[int, ...] = (1, 2, 4),
+ norm_type: str = "group",
+ causality_axis: Optional[str] = "width",
+ dropout: float = 0.0,
+ mid_block_add_attention: bool = False,
+ sample_rate: int = 16000,
+ mel_hop_length: int = 160,
+ is_causal: bool = True,
+ mel_bins: Optional[int] = 64,
+ double_z: bool = True,
+ ):
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.mel_hop_length = mel_hop_length
+ self.is_causal = is_causal
+ self.mel_bins = mel_bins
+
+ self.base_channels = base_channels
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.out_ch = output_channels
+ self.give_pre_end = False
+ self.tanh_out = False
+ self.norm_type = norm_type
+ self.latent_channels = latent_channels
+ self.channel_multipliers = ch_mult
+ self.attn_resolutions = attn_resolutions
+ self.causality_axis = causality_axis
+
+ base_block_channels = base_channels
+ base_resolution = resolution
+ self.z_shape = (1, latent_channels, base_resolution, base_resolution)
+
+ if self.causality_axis is not None:
+ self.conv_in = LTX2AudioCausalConv2d(
+ in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+ )
+ else:
+ self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
+
+ self.down = nn.ModuleList()
+ block_in = base_block_channels
+ curr_res = self.resolution
+
+ for level in range(self.num_resolutions):
+ stage = nn.Module()
+ stage.block = nn.ModuleList()
+ stage.attn = nn.ModuleList()
+ block_out = self.base_channels * self.channel_multipliers[level]
+
+ for _ in range(self.num_res_blocks):
+ stage.block.append(
+ LTX2AudioResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ norm_type=self.norm_type,
+ causality_axis=self.causality_axis,
+ )
+ )
+ block_in = block_out
+ if self.attn_resolutions:
+ if curr_res in self.attn_resolutions:
+ stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
+
+ if level != self.num_resolutions - 1:
+ stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
+ curr_res = curr_res // 2
+
+ self.down.append(stage)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = LTX2AudioResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ norm_type=self.norm_type,
+ causality_axis=self.causality_axis,
+ )
+ if mid_block_add_attention:
+ self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
+ else:
+ self.mid.attn_1 = nn.Identity()
+ self.mid.block_2 = LTX2AudioResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ norm_type=self.norm_type,
+ causality_axis=self.causality_axis,
+ )
+
+ final_block_channels = block_in
+ z_channels = 2 * latent_channels if double_z else latent_channels
+ if self.norm_type == "group":
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
+ elif self.norm_type == "pixel":
+ self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
+ else:
+ raise ValueError(f"Invalid normalization type: {self.norm_type}")
+ self.non_linearity = nn.SiLU()
+
+ if self.causality_axis is not None:
+ self.conv_out = LTX2AudioCausalConv2d(
+ final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+ )
+ else:
+ self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
+ hidden_states = self.conv_in(hidden_states)
+
+ for level in range(self.num_resolutions):
+ stage = self.down[level]
+ for block_idx, block in enumerate(stage.block):
+ hidden_states = block(hidden_states, temb=None)
+ if stage.attn:
+ hidden_states = stage.attn[block_idx](hidden_states)
+
+ if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
+ hidden_states = stage.downsample(hidden_states)
+
+ hidden_states = self.mid.block_1(hidden_states, temb=None)
+ hidden_states = self.mid.attn_1(hidden_states)
+ hidden_states = self.mid.block_2(hidden_states, temb=None)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.non_linearity(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class LTX2AudioDecoder(nn.Module):
+ """
+ Symmetric decoder that reconstructs audio spectrograms from latent features.
+
+ The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
+ convolutions.
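+
+    The decoded spectrogram is cropped and/or zero-padded so that its time dimension equals
+    `frames * LATENT_DOWNSAMPLE_FACTOR` (reduced by `LATENT_DOWNSAMPLE_FACTOR - 1` when a `causality_axis` is set) and
+    its frequency dimension equals `mel_bins`.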
+ """
+
+ def __init__(
+ self,
+ base_channels: int = 128,
+ output_channels: int = 1,
+ num_res_blocks: int = 2,
+ attn_resolutions: Optional[Tuple[int, ...]] = None,
+ in_channels: int = 2,
+ resolution: int = 256,
+ latent_channels: int = 8,
+ ch_mult: Tuple[int, ...] = (1, 2, 4),
+ norm_type: str = "group",
+ causality_axis: Optional[str] = "width",
+ dropout: float = 0.0,
+ mid_block_add_attention: bool = False,
+ sample_rate: int = 16000,
+ mel_hop_length: int = 160,
+ is_causal: bool = True,
+ mel_bins: Optional[int] = 64,
+ ) -> None:
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.mel_hop_length = mel_hop_length
+ self.is_causal = is_causal
+ self.mel_bins = mel_bins
+ self.patchifier = LTX2AudioAudioPatchifier(
+ patch_size=1,
+ audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
+ sample_rate=sample_rate,
+ hop_length=mel_hop_length,
+ is_causal=is_causal,
+ )
+
+ self.base_channels = base_channels
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.out_ch = output_channels
+ self.give_pre_end = False
+ self.tanh_out = False
+ self.norm_type = norm_type
+ self.latent_channels = latent_channels
+ self.channel_multipliers = ch_mult
+ self.attn_resolutions = attn_resolutions
+ self.causality_axis = causality_axis
+
+ base_block_channels = base_channels * self.channel_multipliers[-1]
+ base_resolution = resolution // (2 ** (self.num_resolutions - 1))
+ self.z_shape = (1, latent_channels, base_resolution, base_resolution)
+
+ if self.causality_axis is not None:
+ self.conv_in = LTX2AudioCausalConv2d(
+ latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+ )
+ else:
+ self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
+ self.non_linearity = nn.SiLU()
+ self.mid = nn.Module()
+ self.mid.block_1 = LTX2AudioResnetBlock(
+ in_channels=base_block_channels,
+ out_channels=base_block_channels,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ norm_type=self.norm_type,
+ causality_axis=self.causality_axis,
+ )
+ if mid_block_add_attention:
+ self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
+ else:
+ self.mid.attn_1 = nn.Identity()
+ self.mid.block_2 = LTX2AudioResnetBlock(
+ in_channels=base_block_channels,
+ out_channels=base_block_channels,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ norm_type=self.norm_type,
+ causality_axis=self.causality_axis,
+ )
+
+ self.up = nn.ModuleList()
+ block_in = base_block_channels
+ curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
+
+ for level in reversed(range(self.num_resolutions)):
+ stage = nn.Module()
+ stage.block = nn.ModuleList()
+ stage.attn = nn.ModuleList()
+ block_out = self.base_channels * self.channel_multipliers[level]
+
+ for _ in range(self.num_res_blocks + 1):
+ stage.block.append(
+ LTX2AudioResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ norm_type=self.norm_type,
+ causality_axis=self.causality_axis,
+ )
+ )
+ block_in = block_out
+ if self.attn_resolutions:
+ if curr_res in self.attn_resolutions:
+ stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
+
+ if level != 0:
+ stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
+ curr_res *= 2
+
+ self.up.insert(0, stage)
+
+ final_block_channels = block_in
+
+ if self.norm_type == "group":
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
+ elif self.norm_type == "pixel":
+ self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
+ else:
+ raise ValueError(f"Invalid normalization type: {self.norm_type}")
+
+ if self.causality_axis is not None:
+ self.conv_out = LTX2AudioCausalConv2d(
+ final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+ )
+ else:
+ self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ ) -> torch.Tensor:
+ _, _, frames, mel_bins = sample.shape
+
+ target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
+
+ if self.causality_axis is not None:
+ target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
+
+ target_channels = self.out_ch
+ target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
+
+ hidden_features = self.conv_in(sample)
+ hidden_features = self.mid.block_1(hidden_features, temb=None)
+ hidden_features = self.mid.attn_1(hidden_features)
+ hidden_features = self.mid.block_2(hidden_features, temb=None)
+
+ for level in reversed(range(self.num_resolutions)):
+ stage = self.up[level]
+ for block_idx, block in enumerate(stage.block):
+ hidden_features = block(hidden_features, temb=None)
+ if stage.attn:
+ hidden_features = stage.attn[block_idx](hidden_features)
+
+ if level != 0 and hasattr(stage, "upsample"):
+ hidden_features = stage.upsample(hidden_features)
+
+ if self.give_pre_end:
+ return hidden_features
+
+ hidden = self.norm_out(hidden_features)
+ hidden = self.non_linearity(hidden)
+ decoded_output = self.conv_out(hidden)
+ decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
+
+ _, _, current_time, current_freq = decoded_output.shape
+ target_time = target_frames
+ target_freq = target_mel_bins
+
+ decoded_output = decoded_output[
+ :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
+ ]
+
+ time_padding_needed = target_time - decoded_output.shape[2]
+ freq_padding_needed = target_freq - decoded_output.shape[3]
+
+ if time_padding_needed > 0 or freq_padding_needed > 0:
+ padding = (
+ 0,
+ max(freq_padding_needed, 0),
+ 0,
+ max(time_padding_needed, 0),
+ )
+ decoded_output = F.pad(decoded_output, padding)
+
+ decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
+
+ return decoded_output
+
+
+class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
+ r"""
+ LTX2 audio VAE for encoding and decoding audio latent representations.
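+
+    Example (a minimal, hypothetical usage sketch; it assumes the class is exported from `diffusers` like the other
+    autoencoders and uses an arbitrary random spectrogram shape matching the default `in_channels=2`):
+
+    ```py
+    >>> import torch
+    >>> from diffusers import AutoencoderKLLTX2Audio
+
+    >>> vae = AutoencoderKLLTX2Audio()
+    >>> spectrogram = torch.randn(1, 2, 64, 64)  # (batch, channels, time, num_mel_bins)
+    >>> latents = vae.encode(spectrogram).latent_dist.sample()
+    >>> reconstruction = vae.decode(latents).sample
+    ```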
+ """
+
+ _supports_gradient_checkpointing = False
+
+ @register_to_config
+ def __init__(
+ self,
+ base_channels: int = 128,
+ output_channels: int = 2,
+ ch_mult: Tuple[int, ...] = (1, 2, 4),
+ num_res_blocks: int = 2,
+ attn_resolutions: Optional[Tuple[int, ...]] = None,
+ in_channels: int = 2,
+ resolution: int = 256,
+ latent_channels: int = 8,
+ norm_type: str = "pixel",
+ causality_axis: Optional[str] = "height",
+ dropout: float = 0.0,
+ mid_block_add_attention: bool = False,
+ sample_rate: int = 16000,
+ mel_hop_length: int = 160,
+ is_causal: bool = True,
+ mel_bins: Optional[int] = 64,
+ double_z: bool = True,
+ ) -> None:
+ super().__init__()
+
+ supported_causality_axes = {"none", "width", "height", "width-compatibility"}
+ if causality_axis not in supported_causality_axes:
+ raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
+
+ attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
+
+ self.encoder = LTX2AudioEncoder(
+ base_channels=base_channels,
+ output_channels=output_channels,
+ ch_mult=ch_mult,
+ num_res_blocks=num_res_blocks,
+ attn_resolutions=attn_resolution_set,
+ in_channels=in_channels,
+ resolution=resolution,
+ latent_channels=latent_channels,
+ norm_type=norm_type,
+ causality_axis=causality_axis,
+ dropout=dropout,
+ mid_block_add_attention=mid_block_add_attention,
+ sample_rate=sample_rate,
+ mel_hop_length=mel_hop_length,
+ is_causal=is_causal,
+ mel_bins=mel_bins,
+ double_z=double_z,
+ )
+
+ self.decoder = LTX2AudioDecoder(
+ base_channels=base_channels,
+ output_channels=output_channels,
+ ch_mult=ch_mult,
+ num_res_blocks=num_res_blocks,
+ attn_resolutions=attn_resolution_set,
+ in_channels=in_channels,
+ resolution=resolution,
+ latent_channels=latent_channels,
+ norm_type=norm_type,
+ causality_axis=causality_axis,
+ dropout=dropout,
+ mid_block_add_attention=mid_block_add_attention,
+ sample_rate=sample_rate,
+ mel_hop_length=mel_hop_length,
+ is_causal=is_causal,
+ mel_bins=mel_bins,
+ )
+
+        # Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
+        # computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
+ latents_std = torch.zeros((base_channels,))
+ latents_mean = torch.ones((base_channels,))
+ self.register_buffer("latents_mean", latents_mean, persistent=True)
+ self.register_buffer("latents_std", latents_std, persistent=True)
+
+ # TODO: calculate programmatically instead of hardcoding
+ self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
+ # TODO: confirm whether the mel compression ratio below is correct
+ self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ return self.encoder(x)
+
+ @apply_forward_hook
+ def encode(self, x: torch.Tensor, return_dict: bool = True):
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
+ return self.decoder(z)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z)
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ posterior = self.encode(sample).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ if not return_dict:
+ return (dec.sample,)
+ return dec
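+
+
+# A minimal usage sketch (illustrative only; the top-level `diffusers` import path and the random tensor standing
+# in for real log-mel features are assumptions, and a randomly initialized model will not produce meaningful audio):
+#
+#     import torch
+#     from diffusers import AutoencoderKLLTX2Audio
+#
+#     vae = AutoencoderKLLTX2Audio().eval()
+#     mel = torch.randn(1, 2, 16, 64)  # (batch, in_channels, time frames, mel_bins)
+#     with torch.no_grad():
+#         posterior = vae.encode(mel).latent_dist
+#         latents = posterior.sample()
+#         reconstruction = vae.decode(latents).sample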
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index f4ad1af278f5..153608bb2bf8 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -41,9 +41,11 @@ def enable_cache(self, config) -> None:
Enable caching techniques on the model.
Args:
- config (`Union[PyramidAttentionBroadcastConfig]`):
+ config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
+ - [`~hooks.FasterCacheConfig`]
+ - [`~hooks.FirstBlockCacheConfig`]
Example:
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
deleted file mode 100644
index c18bd8751dcb..000000000000
--- a/src/diffusers/models/controlnet.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Optional, Tuple, Union
-
-from ..utils import deprecate
-from .controlnets.controlnet import ( # noqa
- ControlNetConditioningEmbedding,
- ControlNetModel,
- ControlNetOutput,
- zero_module,
-)
-
-
-class ControlNetOutput(ControlNetOutput):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
- deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
- super().__init__(*args, **kwargs)
-
-
-class ControlNetModel(ControlNetModel):
- def __init__(
- self,
- in_channels: int = 4,
- conditioning_channels: int = 3,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- down_block_types: Tuple[str, ...] = (
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "DownBlock2D",
- ),
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
- only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
- layers_per_block: int = 2,
- downsample_padding: int = 1,
- mid_block_scale_factor: float = 1,
- act_fn: str = "silu",
- norm_num_groups: Optional[int] = 32,
- norm_eps: float = 1e-5,
- cross_attention_dim: int = 1280,
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
- encoder_hid_dim: Optional[int] = None,
- encoder_hid_dim_type: Optional[str] = None,
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
- use_linear_projection: bool = False,
- class_embed_type: Optional[str] = None,
- addition_embed_type: Optional[str] = None,
- addition_time_embed_dim: Optional[int] = None,
- num_class_embeds: Optional[int] = None,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- projection_class_embeddings_input_dim: Optional[int] = None,
- controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- global_pool_conditions: bool = False,
- addition_embed_type_num_heads: int = 64,
- ):
- deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
- deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
- super().__init__(
- in_channels=in_channels,
- conditioning_channels=conditioning_channels,
- flip_sin_to_cos=flip_sin_to_cos,
- freq_shift=freq_shift,
- down_block_types=down_block_types,
- mid_block_type=mid_block_type,
- only_cross_attention=only_cross_attention,
- block_out_channels=block_out_channels,
- layers_per_block=layers_per_block,
- downsample_padding=downsample_padding,
- mid_block_scale_factor=mid_block_scale_factor,
- act_fn=act_fn,
- norm_num_groups=norm_num_groups,
- norm_eps=norm_eps,
- cross_attention_dim=cross_attention_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- encoder_hid_dim=encoder_hid_dim,
- encoder_hid_dim_type=encoder_hid_dim_type,
- attention_head_dim=attention_head_dim,
- num_attention_heads=num_attention_heads,
- use_linear_projection=use_linear_projection,
- class_embed_type=class_embed_type,
- addition_embed_type=addition_embed_type,
- addition_time_embed_dim=addition_time_embed_dim,
- num_class_embeds=num_class_embeds,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- global_pool_conditions=global_pool_conditions,
- addition_embed_type_num_heads=addition_embed_type_num_heads,
- )
-
-
-class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
- deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
- super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py
deleted file mode 100644
index e82748436d86..000000000000
--- a/src/diffusers/models/controlnet_flux.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import List
-
-from ..utils import deprecate, logging
-from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-class FluxControlNetOutput(FluxControlNetOutput):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
- deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
- super().__init__(*args, **kwargs)
-
-
-class FluxControlNetModel(FluxControlNetModel):
- def __init__(
- self,
- patch_size: int = 1,
- in_channels: int = 64,
- num_layers: int = 19,
- num_single_layers: int = 38,
- attention_head_dim: int = 128,
- num_attention_heads: int = 24,
- joint_attention_dim: int = 4096,
- pooled_projection_dim: int = 768,
- guidance_embeds: bool = False,
- axes_dims_rope: List[int] = [16, 56, 56],
- num_mode: int = None,
- conditioning_embedding_channels: int = None,
- ):
- deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
- deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
- super().__init__(
- patch_size=patch_size,
- in_channels=in_channels,
- num_layers=num_layers,
- num_single_layers=num_single_layers,
- attention_head_dim=attention_head_dim,
- num_attention_heads=num_attention_heads,
- joint_attention_dim=joint_attention_dim,
- pooled_projection_dim=pooled_projection_dim,
- guidance_embeds=guidance_embeds,
- axes_dims_rope=axes_dims_rope,
- num_mode=num_mode,
- conditioning_embedding_channels=conditioning_embedding_channels,
- )
-
-
-class FluxMultiControlNetModel(FluxMultiControlNetModel):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
- deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
- super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
deleted file mode 100644
index d239ad4eb3e8..000000000000
--- a/src/diffusers/models/controlnet_sd3.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from ..utils import deprecate, logging
-from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-class SD3ControlNetOutput(SD3ControlNetOutput):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
- deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
- super().__init__(*args, **kwargs)
-
-
-class SD3ControlNetModel(SD3ControlNetModel):
- def __init__(
- self,
- sample_size: int = 128,
- patch_size: int = 2,
- in_channels: int = 16,
- num_layers: int = 18,
- attention_head_dim: int = 64,
- num_attention_heads: int = 18,
- joint_attention_dim: int = 4096,
- caption_projection_dim: int = 1152,
- pooled_projection_dim: int = 2048,
- out_channels: int = 16,
- pos_embed_max_size: int = 96,
- extra_conditioning_channels: int = 0,
- ):
- deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
- deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
- super().__init__(
- sample_size=sample_size,
- patch_size=patch_size,
- in_channels=in_channels,
- num_layers=num_layers,
- attention_head_dim=attention_head_dim,
- num_attention_heads=num_attention_heads,
- joint_attention_dim=joint_attention_dim,
- caption_projection_dim=caption_projection_dim,
- pooled_projection_dim=pooled_projection_dim,
- out_channels=out_channels,
- pos_embed_max_size=pos_embed_max_size,
- extra_conditioning_channels=extra_conditioning_channels,
- )
-
-
-class SD3MultiControlNetModel(SD3MultiControlNetModel):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
- deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
- super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py
deleted file mode 100644
index 5c67af4fe9c1..000000000000
--- a/src/diffusers/models/controlnet_sparsectrl.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import Optional, Tuple, Union
-
-from ..utils import deprecate, logging
-from .controlnets.controlnet_sparsectrl import ( # noqa
- SparseControlNetConditioningEmbedding,
- SparseControlNetModel,
- SparseControlNetOutput,
- zero_module,
-)
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-class SparseControlNetOutput(SparseControlNetOutput):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
- deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
- super().__init__(*args, **kwargs)
-
-
-class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
- def __init__(self, *args, **kwargs):
- deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
- deprecate(
- "diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
- )
- super().__init__(*args, **kwargs)
-
-
-class SparseControlNetModel(SparseControlNetModel):
- def __init__(
- self,
- in_channels: int = 4,
- conditioning_channels: int = 4,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- down_block_types: Tuple[str, ...] = (
- "CrossAttnDownBlockMotion",
- "CrossAttnDownBlockMotion",
- "CrossAttnDownBlockMotion",
- "DownBlockMotion",
- ),
- only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
- layers_per_block: int = 2,
- downsample_padding: int = 1,
- mid_block_scale_factor: float = 1,
- act_fn: str = "silu",
- norm_num_groups: Optional[int] = 32,
- norm_eps: float = 1e-5,
- cross_attention_dim: int = 768,
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
- transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
- temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
- use_linear_projection: bool = False,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- global_pool_conditions: bool = False,
- controlnet_conditioning_channel_order: str = "rgb",
- motion_max_seq_length: int = 32,
- motion_num_attention_heads: int = 8,
- concat_conditioning_mask: bool = True,
- use_simplified_condition_embedding: bool = True,
- ):
- deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
- deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
- super().__init__(
- in_channels=in_channels,
- conditioning_channels=conditioning_channels,
- flip_sin_to_cos=flip_sin_to_cos,
- freq_shift=freq_shift,
- down_block_types=down_block_types,
- only_cross_attention=only_cross_attention,
- block_out_channels=block_out_channels,
- layers_per_block=layers_per_block,
- downsample_padding=downsample_padding,
- mid_block_scale_factor=mid_block_scale_factor,
- act_fn=act_fn,
- norm_num_groups=norm_num_groups,
- norm_eps=norm_eps,
- cross_attention_dim=cross_attention_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- transformer_layers_per_mid_block=transformer_layers_per_mid_block,
- temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
- attention_head_dim=attention_head_dim,
- num_attention_heads=num_attention_heads,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- global_pool_conditions=global_pool_conditions,
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
- motion_max_seq_length=motion_max_seq_length,
- motion_num_attention_heads=motion_num_attention_heads,
- concat_conditioning_mask=concat_conditioning_mask,
- use_simplified_condition_embedding=use_simplified_condition_embedding,
- )
diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py
index 86971271788f..fa374285eec1 100644
--- a/src/diffusers/models/controlnets/controlnet_qwenimage.py
+++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -20,7 +20,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -31,6 +31,7 @@
QwenImageTransformerBlock,
QwenTimestepProjEmbeddings,
RMSNorm,
+ compute_text_seq_len_from_mask,
)
@@ -136,7 +137,7 @@ def forward(
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
- The [`FluxTransformer2DModel`] forward method.
+ The [`QwenImageControlNetModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -147,24 +148,39 @@ def forward(
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
- from the embeddings of input conditions.
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+ Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+ Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+ (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
- A list of tensors that if specified are added to the residuals of transformer blocks.
+ img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+ Image shapes for RoPE computation.
+ txt_seq_lens (`List[int]`, *optional*):
+ **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
+ length.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
- tuple.
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
+ If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
+ the first element is the controlnet block samples.
"""
+ # Handle deprecated txt_seq_lens parameter
+ if txt_seq_lens is not None:
+ deprecate(
+ "txt_seq_lens",
+ "0.39.0",
+ "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
+ "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
+ "and `encoder_hidden_states_mask`.",
+ standard_warn=False,
+ )
+
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -186,32 +202,47 @@ def forward(
temb = self.time_text_embed(timestep, hidden_states)
- image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+ # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+ text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+ encoder_hidden_states, encoder_hidden_states_mask
+ )
+
+ image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
+ # Construct joint attention mask once to avoid reconstructing in every block
+ block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
+ if encoder_hidden_states_mask is not None:
+ # Build joint mask: [text_mask, all_ones_for_image]
+ batch_size, image_seq_len = hidden_states.shape[:2]
+ image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
+ joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+ block_attention_kwargs["attention_mask"] = joint_attention_mask
+
block_samples = ()
- for index_block, block in enumerate(self.transformer_blocks):
+ for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
- encoder_hidden_states_mask,
+ None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
temb,
image_rotary_emb,
+ block_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
- encoder_hidden_states_mask=encoder_hidden_states_mask,
+ encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
temb=temb,
image_rotary_emb=image_rotary_emb,
- joint_attention_kwargs=joint_attention_kwargs,
+ joint_attention_kwargs=block_attention_kwargs,
)
block_samples = block_samples + (hidden_states,)
@@ -267,6 +298,15 @@ def forward(
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[QwenImageControlNetOutput, Tuple]:
+ if txt_seq_lens is not None:
+ deprecate(
+ "txt_seq_lens",
+ "0.39.0",
+ "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
+ "removed in version 0.39.0. The text sequence length is now automatically inferred from "
+ "`encoder_hidden_states` and `encoder_hidden_states_mask`.",
+ standard_warn=False,
+ )
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1:
@@ -281,7 +321,6 @@ def forward(
encoder_hidden_states_mask=encoder_hidden_states_mask,
timestep=timestep,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 8b48ba6b4873..6d2e8df9c286 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -47,6 +47,7 @@
is_torch_version,
logging,
)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
logger = logging.get_logger(__name__)
@@ -354,8 +355,9 @@ def _load_shard_file(
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
+ disable_mmap=False,
):
- state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+ state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
@@ -401,6 +403,7 @@ def _load_shard_files_with_threadpool(
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
+ disable_mmap=False,
):
# Do not spawn anymore workers than you need
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
@@ -427,10 +430,15 @@ def _load_shard_files_with_threadpool(
state_dict_folder=state_dict_folder,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
+ disable_mmap=disable_mmap,
)
+ tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
+ if not is_torch_dist_rank_zero():
+ tqdm_kwargs["disable"] = True
+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
- with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+ with logging.tqdm(**tqdm_kwargs) as pbar:
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
for future in as_completed(futures):
result = future.result()
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 41da95d3a2a2..63e50af61771 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -59,11 +59,8 @@
is_torch_version,
logging,
)
-from ..utils.hub_utils import (
- PushToHubMixin,
- load_or_create_model_card,
- populate_model_card,
-)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (
@@ -602,6 +599,7 @@ def set_attention_backend(self, backend: str) -> None:
from .attention import AttentionModuleMixin
from .attention_dispatch import (
AttentionBackendName,
+ _AttentionBackendRegistry,
_check_attention_backend_requirements,
_maybe_download_kernel_for_backend,
)
@@ -610,6 +608,16 @@ def set_attention_backend(self, backend: str) -> None:
from .attention_processor import Attention, MochiAttention
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+ parallel_config_set = False
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if getattr(processor, "_parallel_config", None) is not None:
+ parallel_config_set = True
+ break
backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
@@ -617,10 +625,17 @@ def set_attention_backend(self, backend: str) -> None:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
+ if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
+ compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+ raise ValueError(
+ f"Context parallelism is enabled but current attention backend '{backend.value}' "
+ f"does not support context parallelism. "
+ f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
+ )
+
_check_attention_backend_requirements(backend)
_maybe_download_kernel_for_backend(backend)
- attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
@@ -629,6 +644,9 @@ def set_attention_backend(self, backend: str) -> None:
continue
processor._attention_backend = backend
+ # Important to set the active backend so that it propagates gracefully throughout.
+ _AttentionBackendRegistry.set_active_backend(backend)
+
def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
@@ -1309,6 +1327,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
+ disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1363,12 +1382,12 @@ def cuda(self, *args, **kwargs):
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
- if getattr(self, "is_loaded_in_8bit", False):
+ if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
raise ValueError(
- "Calling `cuda()` is not supported for `8-bit` quantized models. "
- " Please use the model as it is, since the model has already been set to the correct devices."
+ "Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+ f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
)
- elif is_bitsandbytes_version("<", "0.43.2"):
+ elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
raise ValueError(
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
@@ -1415,17 +1434,16 @@ def to(self, *args, **kwargs):
)
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
- if getattr(self, "is_loaded_in_8bit", False):
+ if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
raise ValueError(
- "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
- " model has already been set to the correct devices and casted to the correct `dtype`."
+ "Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+ f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
)
- elif is_bitsandbytes_version("<", "0.43.2"):
+ elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
raise ValueError(
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
-
if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
logger.warning(
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
@@ -1541,7 +1559,7 @@ def enable_parallelism(
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
- f"calling `enable_parallelism()`."
+ f"calling `model.enable_parallelism()`."
)
# All modules use the same attention processor and backend. We don't need to
@@ -1595,6 +1613,7 @@ def _load_pretrained_model(
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
is_parallel_loading_enabled: Optional[bool] = False,
+ disable_mmap: bool = False,
):
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
@@ -1663,6 +1682,7 @@ def _load_pretrained_model(
state_dict_folder=state_dict_folder,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
+ disable_mmap=disable_mmap,
)
if is_parallel_loading_enabled:
@@ -1672,7 +1692,10 @@ def _load_pretrained_model(
else:
shard_files = resolved_model_file
if len(resolved_model_file) > 1:
- shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+ shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
+ if not is_torch_dist_rank_zero():
+ shard_tqdm_kwargs["disable"] = True
+ shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
for shard_file in shard_files:
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 40b5d4a0dfc9..d9d1b27a1e40 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -27,6 +27,7 @@
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
+ from .transformer_glm_image import GlmImageTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
@@ -35,6 +36,7 @@
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
+ from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 16c526f437f2..1a4464432425 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -22,7 +22,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,11 +717,7 @@ def forward(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
- if is_torch_npu_available():
- freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
- image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
- else:
- image_rotary_emb = self.pos_embed(ids)
+ image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
index c10bf3ed4f7b..9cadfcefc497 100644
--- a/src/diffusers/models/transformers/transformer_flux2.py
+++ b/src/diffusers/models/transformers/transformer_flux2.py
@@ -21,7 +21,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -585,7 +585,13 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
class Flux2TimestepGuidanceEmbeddings(nn.Module):
- def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
+ def __init__(
+ self,
+ in_channels: int = 256,
+ embedding_dim: int = 6144,
+ bias: bool = False,
+ guidance_embeds: bool = True,
+ ):
super().__init__()
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -593,20 +599,24 @@ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
- self.guidance_embedder = TimestepEmbedding(
- in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
- )
+ if guidance_embeds:
+ self.guidance_embedder = TimestepEmbedding(
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+ )
+ else:
+ self.guidance_embedder = None
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
- guidance_proj = self.time_proj(guidance)
- guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
-
- time_guidance_emb = timesteps_emb + guidance_emb
-
- return time_guidance_emb
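+        # The guidance embedding is optional; configs with guidance_embeds=False fall back to the timestep embedding alone.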
+ if guidance is not None and self.guidance_embedder is not None:
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
+ time_guidance_emb = timesteps_emb + guidance_emb
+ return time_guidance_emb
+ else:
+ return timesteps_emb
class Flux2Modulation(nn.Module):
@@ -698,6 +708,7 @@ def __init__(
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
rope_theta: int = 2000,
eps: float = 1e-6,
+ guidance_embeds: bool = True,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -708,7 +719,10 @@ def __init__(
# 2. Combined timestep + guidance embedding
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
- in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+ in_channels=timestep_guidance_channels,
+ embedding_dim=self.inner_dim,
+ bias=False,
+ guidance_embeds=guidance_embeds,
)
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -815,7 +829,9 @@ def forward(
# 1. Calculate timestep embedding and modulation parameters
timestep = timestep.to(hidden_states.dtype) * 1000
- guidance = guidance.to(hidden_states.dtype) * 1000
+
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_guidance_embed(timestep, guidance)
@@ -835,14 +851,8 @@ def forward(
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]
- if is_torch_npu_available():
- freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
- image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
- freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
- text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
- else:
- image_rotary_emb = self.pos_embed(img_ids)
- text_rotary_emb = self.pos_embed(txt_ids)
+ image_rotary_emb = self.pos_embed(img_ids)
+ text_rotary_emb = self.pos_embed(txt_ids)
concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py
new file mode 100644
index 000000000000..b7b3aa391ce4
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_glm_image.py
@@ -0,0 +1,621 @@
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import LayerNorm, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
+ def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+ self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ target_size: torch.Tensor,
+ crop_coords: torch.Tensor,
+ hidden_dtype: torch.dtype,
+ ) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+
+ crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+ target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+ # (B, 2 * condition_dim)
+ condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)
+
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
+ condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
+
+ conditioning = timesteps_emb + condition_emb
+ conditioning = F.silu(conditioning)
+
+ return conditioning
+
+
+class GlmImageImageProjector(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ hidden_size: int = 2560,
+ patch_size: int = 2,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, channel, height, width = hidden_states.shape
+ post_patch_height = height // self.patch_size
+ post_patch_width = width // self.patch_size
+
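+        # Patchify: (B, C, H, W) -> (B, H/p * W/p, C * p * p), then project each patch to hidden_size.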
+ hidden_states = hidden_states.reshape(
+ batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
+ )
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+ hidden_states = self.proj(hidden_states)
+
+ return hidden_states
+
+
+class GlmImageAdaLayerNormZero(nn.Module):
+ def __init__(self, embedding_dim: int, dim: int) -> None:
+ super().__init__()
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+
+ def forward(
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ dtype = hidden_states.dtype
+ norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+ norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
+
+ emb = self.linear(temb)
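+        # Image-stream and context-stream (c_*) modulation parameters are interleaved in the projected embedding.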
+ (
+ shift_msa,
+ c_shift_msa,
+ scale_msa,
+ c_scale_msa,
+ gate_msa,
+ c_gate_msa,
+ shift_mlp,
+ c_shift_mlp,
+ scale_mlp,
+ c_scale_mlp,
+ gate_mlp,
+ c_gate_mlp,
+ ) = emb.chunk(12, dim=1)
+
+ hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+ encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
+
+ return (
+ hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ )
+
+
+class GlmImageLayerKVCache:
+ """KV cache for GlmImage model."""
+
+ def __init__(self):
+ self.k_cache = None
+ self.v_cache = None
+ self.mode: Optional[str] = None # "write", "read", "skip"
+
+ def store(self, k: torch.Tensor, v: torch.Tensor):
+ if self.k_cache is None:
+ self.k_cache = k
+ self.v_cache = v
+ else:
+ self.k_cache = torch.cat([self.k_cache, k], dim=1)
+ self.v_cache = torch.cat([self.v_cache, v], dim=1)
+
+ def get(self, k: torch.Tensor, v: torch.Tensor):
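+        # Broadcast the cached K/V along the batch dimension when the incoming batch is larger than the cached one
+        # (e.g., when classifier-free guidance doubles the batch), then prepend the cache along the sequence dimension.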
+ if self.k_cache.shape[0] != k.shape[0]:
+ k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
+ v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
+ else:
+ k_cache_expanded = self.k_cache
+ v_cache_expanded = self.v_cache
+
+ k_cache = torch.cat([k_cache_expanded, k], dim=1)
+ v_cache = torch.cat([v_cache_expanded, v], dim=1)
+ return k_cache, v_cache
+
+ def clear(self):
+ self.k_cache = None
+ self.v_cache = None
+ self.mode = None
+
+
+class GlmImageKVCache:
+ """Container for all layers' KV caches."""
+
+ def __init__(self, num_layers: int):
+ self.num_layers = num_layers
+ self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
+
+ def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
+ return self.caches[layer_idx]
+
+ def set_mode(self, mode: Optional[str]):
+ if mode is not None and mode not in ["write", "read", "skip"]:
+ raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
+ for cache in self.caches:
+ cache.mode = mode
+
+ def clear(self):
+ for cache in self.caches:
+ cache.clear()
+
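+# A sketch of how a caller might drive this cache (an assumed call pattern, not an API defined in this file):
+# write K/V once on a conditioning pass, then reuse it on later passes.
+#
+#     kv_cache = GlmImageKVCache(num_layers)
+#     kv_cache.set_mode("write")   # first pass: each block stores K/V through kv_cache[layer_idx]
+#     # ... run the transformer blocks, passing kv_cache[i] to block i ...
+#     kv_cache.set_mode("read")    # later passes: cached K/V is prepended to the new keys/values
+#     # ... run again ...
+#     kv_cache.clear()
+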
+
+class GlmImageAttnProcessor:
+ """
+ Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+
+ The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
+ text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ kv_cache: Optional[GlmImageLayerKVCache] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ dtype = encoder_hidden_states.dtype
+
+ batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+ batch_size, image_seq_length, embed_dim = hidden_states.shape
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query).to(dtype=dtype)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key).to(dtype=dtype)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ query[:, text_seq_length:, :, :] = apply_rotary_emb(
+ query[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
+ )
+ key[:, text_seq_length:, :, :] = apply_rotary_emb(
+ key[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
+ )
+
+ if kv_cache is not None:
+ if kv_cache.mode == "write":
+ kv_cache.store(key, value)
+ elif kv_cache.mode == "read":
+ key, value = kv_cache.get(key, value)
+ elif kv_cache.mode == "skip":
+ pass
+
+ # 4. Attention
+ if attention_mask is not None:
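+            # Expand the (batch, text_seq_len) text mask into a joint (batch, 1, seq, seq) mask via an outer
+            # product, so positions involving padded text tokens are masked out in both directions.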
+ text_attn_mask = attention_mask
+ assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+ text_attn_mask = text_attn_mask.float().to(query.device)
+ mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+ mix_attn_mask[:, :text_seq_length] = text_attn_mask
+ mix_attn_mask = mix_attn_mask.unsqueeze(2)
+ attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+ attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # 5. Output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class GlmImageTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int = 2560,
+ num_attention_heads: int = 64,
+ attention_head_dim: int = 40,
+ time_embed_dim: int = 512,
+ ) -> None:
+ super().__init__()
+
+ # 1. Attention
+ self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ out_dim=dim,
+ bias=True,
+ qk_norm="layer_norm",
+ elementwise_affine=False,
+ eps=1e-5,
+ processor=GlmImageAttnProcessor(),
+ )
+
+ # 2. Feedforward
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[
+ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+ ] = None,
+ attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ kv_cache: Optional[GlmImageLayerKVCache] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Timestep conditioning
+ (
+ norm_hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ norm_encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ ) = self.norm1(hidden_states, encoder_hidden_states, temb)
+
+ # 2. Attention
+ attention_kwargs = attention_kwargs or {}
+
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ kv_cache=kv_cache,
+ **attention_kwargs,
+ )
+ hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
+
+ # 3. Feedforward
+ norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
+ 1 + c_scale_mlp.unsqueeze(1)
+ ) + c_shift_mlp.unsqueeze(1)
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output_context = self.ff(norm_encoder_hidden_states)
+ hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
+
+ return hidden_states, encoder_hidden_states
+
+
+class GlmImageRotaryPosEmbed(nn.Module):
+ def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
+ super().__init__()
+
+ self.dim = dim
+ self.patch_size = patch_size
+ self.theta = theta
+
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_channels, height, width = hidden_states.shape
+ height, width = height // self.patch_size, width // self.patch_size
+
+ dim_h, dim_w = self.dim // 2, self.dim // 2
+ h_inv_freq = 1.0 / (
+ self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+ )
+ w_inv_freq = 1.0 / (
+ self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+ )
+ h_seq = torch.arange(height)
+ w_seq = torch.arange(width)
+ freqs_h = torch.outer(h_seq, h_inv_freq)
+ freqs_w = torch.outer(w_seq, w_inv_freq)
+
+ # Create position matrices for height and width
+ # [height, 1, dim//4] and [1, width, dim//4]
+ freqs_h = freqs_h.unsqueeze(1)
+ freqs_w = freqs_w.unsqueeze(0)
+ # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+ freqs_h = freqs_h.expand(height, width, -1)
+ freqs_w = freqs_w.expand(height, width, -1)
+
+ # Concatenate along last dimension to get [height, width, dim//2]
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+ freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
+ freqs = freqs.reshape(height * width, -1)
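+ # Note (added for clarity): the returned cos/sin tables have shape (height * width, dim) and are
+ # applied only to the image-token slice of the sequence in the attention processor above.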
+ return (freqs.cos(), freqs.sin())
+
+
+class GlmImageAdaLayerNormContinuous(nn.Module):
+ """
+ GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+ Linear on conditioning embedding.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ norm_type: str = "layer_norm",
+ ):
+ super().__init__()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ # *** NO SiLU here ***
+ emb = self.linear(conditioning_embedding.to(x.dtype))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
+class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+ r"""
+ Args:
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, defaults to `40`):
+ The number of channels in each head.
+ num_attention_heads (`int`, defaults to `64`):
+ The number of heads to use for multi-head attention.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_embed_dim (`int`, defaults to `1472`):
+ Input dimension of text embeddings from the text encoder.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ condition_dim (`int`, defaults to `256`):
+ The embedding dimension of the input SDXL-style resolution conditions (target_size, crop_coords).
+ prior_vq_quantizer_codebook_size (`int`, defaults to `16384`):
+ The codebook size of the prior VQ quantizer; used as the vocabulary size of the prior token
+ embedding.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = [
+ "GlmImageTransformerBlock",
+ "GlmImageImageProjector",
+ ]
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
+ _skip_keys = ["kv_caches"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ num_layers: int = 30,
+ attention_head_dim: int = 40,
+ num_attention_heads: int = 64,
+ text_embed_dim: int = 1472,
+ time_embed_dim: int = 512,
+ condition_dim: int = 256,
+ prior_vq_quantizer_codebook_size: int = 16384,
+ ):
+ super().__init__()
+
+ # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
+ # Each of these is a sincos embedding of shape 2 * condition_dim
+ pooled_projection_dim = 2 * 2 * condition_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels
+
+ # 1. RoPE
+ self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
+
+ # 2. Patch & Text-timestep embedding
+ self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
+ self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
+ self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
+ self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
+
+ self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
+ embedding_dim=time_embed_dim,
+ condition_dim=condition_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ timesteps_dim=time_embed_dim,
+ )
+
+ # 3. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output projection
+ self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ prior_token_id: torch.Tensor,
+ prior_token_drop: torch.Tensor,
+ timestep: torch.LongTensor,
+ target_size: torch.Tensor,
+ crop_coords: torch.Tensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ attention_mask: Optional[torch.Tensor] = None,
+ kv_caches: Optional[GlmImageKVCache] = None,
+ image_rotary_emb: Optional[
+ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+ ] = None,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ batch_size, num_channels, height, width = hidden_states.shape
+
+ # 1. RoPE
+ if image_rotary_emb is None:
+ image_rotary_emb = self.rope(hidden_states)
+
+ # 2. Patch & Timestep embeddings
+ p = self.config.patch_size
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ hidden_states = self.image_projector(hidden_states)
+ encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
+ prior_embedding = self.prior_token_embedding(prior_token_id)
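+ # Zero out the prior-token embeddings for samples flagged in `prior_token_drop`
+ # (presumably conditioning dropout, e.g. for classifier-free-style guidance on the prior tokens).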
+ prior_embedding[prior_token_drop] *= 0.0
+ prior_hidden_states = self.prior_projector(prior_embedding)
+
+ hidden_states = hidden_states + prior_hidden_states
+
+ temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
+
+ # 3. Transformer blocks
+ for idx, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_mask,
+ attention_kwargs,
+ kv_caches[idx] if kv_caches is not None else None,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_mask,
+ attention_kwargs,
+ kv_cache=kv_caches[idx] if kv_caches is not None else None,
+ )
+
+ # 4. Output norm & projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+
+ # Rearrange tensor from (B, H_p, W_p, C, p, p) to (B, C, H_p * p, W_p * p)
+ output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index fb0ce1a30ff9..4f0775ac9fa0 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -312,7 +312,6 @@ def forward(
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
- conditioning = timesteps_emb + pooled_projections
token_replace_emb = None
if self.image_condition_type == "token_replace":
@@ -324,8 +323,9 @@ def forward(
if self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
- conditioning = conditioning + guidance_emb
-
+ conditioning = timesteps_emb + guidance_emb + pooled_projections
+ else:
+ conditioning = timesteps_emb + pooled_projections
return conditioning, token_replace_emb
diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 316e79da4fd6..57b28991d255 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -165,9 +165,8 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
- @torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
- args = torch.outer(time, self.freqs.to(device=time.device))
+ args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
@@ -269,7 +268,6 @@ def __init__(self, time_dim, model_dim, num_params):
self.out_layer.weight.data.zero_()
self.out_layer.bias.data.zero_()
- @torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
return self.out_layer(self.activation(x))
@@ -525,6 +523,7 @@ class Kandinsky5Transformer3DModel(
"Kandinsky5TransformerEncoderBlock",
"Kandinsky5TransformerDecoderBlock",
]
+ _keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
_supports_gradient_checkpointing = True
@register_to_config
diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py
index 7fbaaa3fee85..2696f5e78701 100644
--- a/src/diffusers/models/transformers/transformer_longcat_image.py
+++ b/src/diffusers/models/transformers/transformer_longcat_image.py
@@ -21,7 +21,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -499,11 +499,7 @@ def forward(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
ids = torch.cat((txt_ids, img_ids), dim=0)
- if is_torch_npu_available():
- freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
- image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
- else:
- image_rotary_emb = self.pos_embed(ids)
+ image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py
new file mode 100644
index 000000000000..b88f096e8033
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_ltx2.py
@@ -0,0 +1,1350 @@
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import (
+ USE_PEFT_BACKEND,
+ BaseOutput,
+ is_torch_version,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection
+from ..modeling_utils import ModelMixin
+from ..normalization import RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
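+ # Rotates each interleaved (even, odd) channel pair in fp32 and casts back to the input dtype:
+ # out_even = x_even * cos - x_odd * sin, out_odd = x_odd * cos + x_even * sin.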
+ cos, sin = freqs
+ x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+
+def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
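+ # Split convention: the first half of the last dimension acts as the "real" part and the second
+ # half as the "imaginary" part; cos/sin carry half the per-head dim and broadcast over both halves.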
+ cos, sin = freqs
+
+ x_dtype = x.dtype
+ needs_reshape = False
+ if x.ndim != 4 and cos.ndim == 4:
+ # cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
+ # The cos/sin batch dim may only be broadcastable, so take batch size from x
+ b = x.shape[0]
+ _, h, t, _ = cos.shape
+ x = x.reshape(b, t, h, -1).swapaxes(1, 2)
+ needs_reshape = True
+
+ # Split last dim (2*r) into (d=2, r)
+ last = x.shape[-1]
+ if last % 2 != 0:
+ raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.")
+ r = last // 2
+
+ # (..., 2, r)
+ split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
+ first_x = split_x[..., :1, :] # (..., 1, r)
+ second_x = split_x[..., 1:, :] # (..., 1, r)
+
+ cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r)
+ sin_u = sin.unsqueeze(-2)
+
+ out = split_x * cos_u
+ first_out = out[..., :1, :]
+ second_out = out[..., 1:, :]
+
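+ # `first_out` and `second_out` are views into `out`, so the in-place addcmul_ calls below complete
+ # the rotation (real * cos - imag * sin, imag * cos + real * sin) without extra allocations.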
+ first_out.addcmul_(-sin_u, second_x)
+ second_out.addcmul_(sin_u, first_x)
+
+ out = out.reshape(*out.shape[:-2], last)
+
+ if needs_reshape:
+ out = out.swapaxes(1, 2).reshape(b, t, -1)
+
+ out = out.to(dtype=x_dtype)
+ return out
+
+
+@dataclass
+class AudioVisualModelOutput(BaseOutput):
+ r"""
+ Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs.
+
+ Args:
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+ The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output
+ of the model. This is typically a video (spatiotemporal) output.
+ audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`):
+ The audio output of the audiovisual model.
+ """
+
+ sample: "torch.Tensor" # noqa: F821
+ audio_sample: "torch.Tensor" # noqa: F821
+
+
+class LTX2AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0
+ model. In particular, the number of modulation parameters to be calculated is now configurable.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_mod_params (`int`, *optional*, defaults to `6`):
+ The number of modulation parameters which will be calculated in the first return argument. The default of 6
+ is standard, but sometimes we may want to have a different (usually smaller) number of modulation
+ parameters.
+ use_additional_conditions (`bool`, *optional*, defaults to `False`):
+ Whether to use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False):
+ super().__init__()
+ self.num_mod_params = num_mod_params
+
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ batch_size: Optional[int] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+class LTX2AudioVideoAttnProcessor:
+ r"""
+ Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model.
+ Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can
+ support audio-to-video (a2v) and video-to-audio (v2a) cross attention.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if is_torch_version("<", "2.0"):
+ raise ValueError(
+ "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
+ )
+
+ def __call__(
+ self,
+ attn: "LTX2Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if query_rotary_emb is not None:
+ if attn.rope_type == "interleaved":
+ query = apply_interleaved_rotary_emb(query, query_rotary_emb)
+ key = apply_interleaved_rotary_emb(
+ key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
+ )
+ elif attn.rope_type == "split":
+ query = apply_split_rotary_emb(query, query_rotary_emb)
+ key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
+ r"""
+ Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
+ RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
+ """
+
+ _default_processor_cls = LTX2AudioVideoAttnProcessor
+ _available_processors = [LTX2AudioVideoAttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ kv_heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = True,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ qk_norm: str = "rms_norm_across_heads",
+ norm_eps: float = 1e-6,
+ norm_elementwise_affine: bool = True,
+ rope_type: str = "interleaved",
+ processor=None,
+ ):
+ super().__init__()
+ if qk_norm != "rms_norm_across_heads":
+ raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = query_dim
+ self.heads = heads
+ self.rope_type = rope_type
+
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ hidden_states = self.processor(
+ self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
+ )
+ return hidden_states
+
+
+class LTX2VideoTransformerBlock(nn.Module):
+ r"""
+ Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video).
+
+ Args:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
+ The normalization layer to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ cross_attention_dim: int,
+ audio_dim: int,
+ audio_num_attention_heads: int,
+ audio_attention_head_dim: int,
+ audio_cross_attention_dim: int,
+ qk_norm: str = "rms_norm_across_heads",
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = True,
+ attention_out_bias: bool = True,
+ eps: float = 1e-6,
+ elementwise_affine: bool = False,
+ rope_type: str = "interleaved",
+ ):
+ super().__init__()
+
+ # 1. Self-Attention (video and audio)
+ self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.attn1 = LTX2Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ kv_heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ rope_type=rope_type,
+ )
+
+ self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.audio_attn1 = LTX2Attention(
+ query_dim=audio_dim,
+ heads=audio_num_attention_heads,
+ kv_heads=audio_num_attention_heads,
+ dim_head=audio_attention_head_dim,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ rope_type=rope_type,
+ )
+
+ # 2. Prompt Cross-Attention
+ self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.attn2 = LTX2Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ kv_heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ rope_type=rope_type,
+ )
+
+ self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.audio_attn2 = LTX2Attention(
+ query_dim=audio_dim,
+ cross_attention_dim=audio_cross_attention_dim,
+ heads=audio_num_attention_heads,
+ kv_heads=audio_num_attention_heads,
+ dim_head=audio_attention_head_dim,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ rope_type=rope_type,
+ )
+
+ # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
+ # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio
+ self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.audio_to_video_attn = LTX2Attention(
+ query_dim=dim,
+ cross_attention_dim=audio_dim,
+ heads=audio_num_attention_heads,
+ kv_heads=audio_num_attention_heads,
+ dim_head=audio_attention_head_dim,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ rope_type=rope_type,
+ )
+
+ # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
+ self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.video_to_audio_attn = LTX2Attention(
+ query_dim=audio_dim,
+ cross_attention_dim=dim,
+ heads=audio_num_attention_heads,
+ kv_heads=audio_num_attention_heads,
+ dim_head=audio_attention_head_dim,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ rope_type=rope_type,
+ )
+
+ # 4. Feedforward layers
+ self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.ff = FeedForward(dim, activation_fn=activation_fn)
+
+ self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
+
+ # 5. Per-Layer Modulation Parameters
+ # Self-Attention / Feedforward AdaLayerNorm-Zero mod params
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+ self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
+
+ # Per-layer a2v, v2a Cross-Attention mod params
+ self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
+ self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ audio_hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ audio_encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ temb_audio: torch.Tensor,
+ temb_ca_scale_shift: torch.Tensor,
+ temb_ca_audio_scale_shift: torch.Tensor,
+ temb_ca_gate: torch.Tensor,
+ temb_ca_audio_gate: torch.Tensor,
+ video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ audio_encoder_attention_mask: Optional[torch.Tensor] = None,
+ a2v_cross_attention_mask: Optional[torch.Tensor] = None,
+ v2a_cross_attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.size(0)
+
+ # 1. Video and Audio Self-Attention
+ norm_hidden_states = self.norm1(hidden_states)
+
+ num_ada_params = self.scale_shift_table.shape[0]
+ ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+ batch_size, temb.size(1), num_ada_params, -1
+ )
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+
+ attn_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ query_rotary_emb=video_rotary_emb,
+ )
+ hidden_states = hidden_states + attn_hidden_states * gate_msa
+
+ norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
+
+ num_audio_ada_params = self.audio_scale_shift_table.shape[0]
+ audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
+ batch_size, temb_audio.size(1), num_audio_ada_params, -1
+ )
+ audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
+ audio_ada_values.unbind(dim=2)
+ )
+ norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
+
+ attn_audio_hidden_states = self.audio_attn1(
+ hidden_states=norm_audio_hidden_states,
+ encoder_hidden_states=None,
+ query_rotary_emb=audio_rotary_emb,
+ )
+ audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
+
+ # 2. Video and Audio Cross-Attention with the text embeddings
+ norm_hidden_states = self.norm2(hidden_states)
+ attn_hidden_states = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ query_rotary_emb=None,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = hidden_states + attn_hidden_states
+
+ norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
+ attn_audio_hidden_states = self.audio_attn2(
+ norm_audio_hidden_states,
+ encoder_hidden_states=audio_encoder_hidden_states,
+ query_rotary_emb=None,
+ attention_mask=audio_encoder_attention_mask,
+ )
+ audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
+
+ # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
+ norm_hidden_states = self.audio_to_video_norm(hidden_states)
+ norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
+
+ # Combine global and per-layer cross attention modulation parameters
+ # Video
+ video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
+ video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
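+ # The per-layer tables, of shape (4, dim) and (1, dim), broadcast against the global
+ # timestep-conditioned parameters reshaped to (B, T, 4, dim) and (B, T, 1, dim) below.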
+
+ video_ca_scale_shift_table = (
+ video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
+ + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
+ ).unbind(dim=2)
+ video_ca_gate = (
+ video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
+ + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
+ ).unbind(dim=2)
+
+ video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
+ a2v_gate = video_ca_gate[0].squeeze(2)
+
+ # Audio
+ audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
+ audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
+
+ audio_ca_scale_shift_table = (
+ audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
+ + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
+ ).unbind(dim=2)
+ audio_ca_gate = (
+ audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
+ + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
+ ).unbind(dim=2)
+
+ audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
+ v2a_gate = audio_ca_gate[0].squeeze(2)
+
+ # Audio-to-Video Cross Attention: Q: Video; K,V: Audio
+ mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
+ 2
+ )
+ mod_norm_audio_hidden_states = norm_audio_hidden_states * (
+ 1 + audio_a2v_ca_scale.squeeze(2)
+ ) + audio_a2v_ca_shift.squeeze(2)
+
+ a2v_attn_hidden_states = self.audio_to_video_attn(
+ mod_norm_hidden_states,
+ encoder_hidden_states=mod_norm_audio_hidden_states,
+ query_rotary_emb=ca_video_rotary_emb,
+ key_rotary_emb=ca_audio_rotary_emb,
+ attention_mask=a2v_cross_attention_mask,
+ )
+
+ hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
+
+ # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
+ mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
+ 2
+ )
+ mod_norm_audio_hidden_states = norm_audio_hidden_states * (
+ 1 + audio_v2a_ca_scale.squeeze(2)
+ ) + audio_v2a_ca_shift.squeeze(2)
+
+ v2a_attn_hidden_states = self.video_to_audio_attn(
+ mod_norm_audio_hidden_states,
+ encoder_hidden_states=mod_norm_hidden_states,
+ query_rotary_emb=ca_audio_rotary_emb,
+ key_rotary_emb=ca_video_rotary_emb,
+ attention_mask=v2a_cross_attention_mask,
+ )
+
+ audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
+
+ # 4. Feedforward
+ norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp
+ audio_ff_output = self.audio_ff(norm_audio_hidden_states)
+ audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp
+
+ return hidden_states, audio_hidden_states
+
+
+class LTX2AudioVideoRotaryPosEmbed(nn.Module):
+ """
+ Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model.
+
+ Args:
+ causal_offset (`int`, *optional*, defaults to `1`):
+ Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE
+ treats the very first frame differently), but could also be 0 (for non-causal modeling).
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ patch_size: int = 1,
+ patch_size_t: int = 1,
+ base_num_frames: int = 20,
+ base_height: int = 2048,
+ base_width: int = 2048,
+ sampling_rate: int = 16000,
+ hop_length: int = 160,
+ scale_factors: Tuple[int, ...] = (8, 32, 32),
+ theta: float = 10000.0,
+ causal_offset: int = 1,
+ modality: str = "video",
+ double_precision: bool = True,
+ rope_type: str = "interleaved",
+ num_attention_heads: int = 32,
+ ) -> None:
+ super().__init__()
+
+ self.dim = dim
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+
+ if rope_type not in ["interleaved", "split"]:
+ raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
+ self.rope_type = rope_type
+
+ self.base_num_frames = base_num_frames
+ self.num_attention_heads = num_attention_heads
+
+ # Video-specific
+ self.base_height = base_height
+ self.base_width = base_width
+
+ # Audio-specific
+ self.sampling_rate = sampling_rate
+ self.hop_length = hop_length
+ self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0])
+
+ self.scale_factors = scale_factors
+ self.theta = theta
+ self.causal_offset = causal_offset
+
+ self.modality = modality
+ if self.modality not in ["video", "audio"]:
+ raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.")
+ self.double_precision = double_precision
+
+ def prepare_video_coords(
+ self,
+ batch_size: int,
+ num_frames: int,
+ height: int,
+ width: int,
+ device: torch.device,
+ fps: float = 24.0,
+ ) -> torch.Tensor:
+ """
+ Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel
+ space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2)
+ where
+ - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames)
+ - axis 3 (size 2) stores `[start, end)` indices within each dimension
+
+ Args:
+ batch_size (`int`):
+ Batch size of the video latents.
+ num_frames (`int`):
+ Number of latent frames in the video latents.
+ height (`int`):
+ Latent height of the video latents.
+ width (`int`):
+ Latent width of the video latents.
+ device (`torch.device`):
+ Device on which to create the video grid.
+
+ Returns:
+ `torch.Tensor`:
+ Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2].
+ """
+
+ # 1. Generate grid coordinates for each spatiotemporal dimension (frames, height, width)
+ # Always compute rope in fp32
+ grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device)
+ grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device)
+ grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device)
+ # indexing='ij' ensures that the dimensions are kept in order as (frames, height, width)
+ grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
+ grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches
+
+ # 2. Get the patch boundaries with respect to the latent video grid
+ patch_size = (self.patch_size_t, self.patch_size, self.patch_size)
+ patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device)
+ patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)
+
+ # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension
+ latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2]
+ # Reshape to (batch_size, 3, num_patches, 2)
+ latent_coords = latent_coords.flatten(1, 3)
+ latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+
+ # 3. Calculate the pixel space patch boundaries from the latent boundaries.
+ scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)
+ # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape
+ broadcast_shape = [1] * latent_coords.ndim
+ broadcast_shape[1] = -1 # This is the (frame, height, width) dim
+ # Apply per-axis scaling to convert latent coordinates to pixel space coordinates
+ pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape)
+
+ # As the VAE temporal stride for the first frame is 1 instead of self.vae_scale_factors[0], we need to shift
+ # and clamp to keep the first-frame timestamps causal and non-negative.
+ pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0)
+
+ # Scale the temporal coordinates by the video FPS
+ pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
+
+ return pixel_coords
+
+ def prepare_audio_coords(
+ self,
+ batch_size: int,
+ num_frames: int,
+ device: torch.device,
+ shift: int = 0,
+ ) -> torch.Tensor:
+ """
+ Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame.
+ This will ultimately have shape (batch_size, 1, num_patches, 2) where
+ - axis 1 (size 1) represents the temporal dimension
+ - axis 3 (size 2) stores `[start, end)` indices within each dimension
+
+ Args:
+ batch_size (`int`):
+ Batch size of the audio latents.
+ num_frames (`int`):
+ Number of latent frames in the audio latents.
+ device (`torch.device`):
+ Device on which to create the audio grid.
+ shift (`int`, *optional*, defaults to `0`):
+ Offset on the latent indices. Different shift values correspond to different overlapping windows with
+ respect to the same underlying latent grid.
+
+ Returns:
+ `torch.Tensor`:
+ Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2].
+ """
+
+ # 1. Generate coordinates in the frame (time) dimension.
+ # Always compute rope in fp32
+ grid_f = torch.arange(
+ start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device
+ )
+
+ # 2. Calculate start timestamps in seconds with respect to the original spectrogram grid
+ audio_scale_factor = self.scale_factors[0]
+ # Scale back to mel spectrogram space
+ grid_start_mel = grid_f * audio_scale_factor
+ # Handle first frame causal offset, ensuring non-negative timestamps
+ grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0)
+ # Convert mel bins back into seconds
+ grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
+
+ # 3. Calculate end timestamps in seconds with respect to the original spectrogram grid
+ grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
+ grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0)
+ grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
+
+ audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2]
+ audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2]
+ audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2]
+ return audio_coords
+
+ def prepare_coords(self, *args, **kwargs):
+ if self.modality == "video":
+ return self.prepare_video_coords(*args, **kwargs)
+ elif self.modality == "audio":
+ return self.prepare_audio_coords(*args, **kwargs)
+
+ def forward(
+ self, coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or coords.device
+
+ # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
+ num_pos_dims = coords.shape[1]
+
+ # 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch
+ # position index
+ if coords.ndim == 4:
+ coords_start, coords_end = coords.chunk(2, dim=-1)
+ coords = (coords_start + coords_end) / 2.0
+ coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches]
+
+ # 2. Get coordinates as a fraction of the base data shape
+ if self.modality == "video":
+ max_positions = (self.base_num_frames, self.base_height, self.base_width)
+ elif self.modality == "audio":
+ max_positions = (self.base_num_frames,)
+ # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims]
+ grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device)
+ # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin
+ num_rope_elems = num_pos_dims * 2
+
+ # 3. Create a 1D grid of frequencies for RoPE
+ freqs_dtype = torch.float64 if self.double_precision else torch.float32
+ pow_indices = torch.pow(
+ self.theta,
+ torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
+ )
+ freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
+
+ # 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
+ # (self.dim // num_elems,)
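+ # Note (added for clarity): `grid * 2 - 1` maps the base-normalized positions from [0, 1] onto
+ # [-1, 1] before they are scaled by the frequency vector.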
+ freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems]
+ freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2]
+
+ # 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim
+ # TODO: consider implementing this as a utility and reuse in `connectors.py`.
+ # src/diffusers/pipelines/ltx2/connectors.py
+ if self.rope_type == "interleaved":
+ cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
+ sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
+
+ if self.dim % num_rope_elems != 0:
+ cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
+ sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems])
+ cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
+ sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
+
+ elif self.rope_type == "split":
+ expected_freqs = self.dim // 2
+ current_freqs = freqs.shape[-1]
+ pad_size = expected_freqs - current_freqs
+ cos_freq = freqs.cos()
+ sin_freq = freqs.sin()
+
+ if pad_size != 0:
+ cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
+ sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
+
+ cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
+ sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
+
+ # Reshape freqs to be compatible with multi-head attention
+ b = cos_freq.shape[0]
+ t = cos_freq.shape[1]
+
+ cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
+ sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
+
+ cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
+ sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
+
+ return cos_freqs, sin_freqs
+
+
+class LTX2VideoTransformer3DModel(
+ ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
+ r"""
+ A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
+
+ Args:
+ in_channels (`int`, defaults to `128`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `128`):
+ The number of channels in the output.
+ patch_size (`int`, defaults to `1`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the temporal patches to use in the patch embedding layer.
+ num_attention_heads (`int`, defaults to `32`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ cross_attention_dim (`int`, defaults to `4096`):
+ The number of channels for cross attention heads.
+ num_layers (`int`, defaults to `48`):
+ The number of layers of Transformer blocks to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
+ The normalization layer to use.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["norm"]
+ _repeated_blocks = ["LTX2VideoTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 128, # Video Arguments
+ out_channels: Optional[int] = 128,
+ patch_size: int = 1,
+ patch_size_t: int = 1,
+ num_attention_heads: int = 32,
+ attention_head_dim: int = 128,
+ cross_attention_dim: int = 4096,
+ vae_scale_factors: Tuple[int, int, int] = (8, 32, 32),
+ pos_embed_max_pos: int = 20,
+ base_height: int = 2048,
+ base_width: int = 2048,
+ audio_in_channels: int = 128, # Audio Arguments
+ audio_out_channels: Optional[int] = 128,
+ audio_patch_size: int = 1,
+ audio_patch_size_t: int = 1,
+ audio_num_attention_heads: int = 32,
+ audio_attention_head_dim: int = 64,
+ audio_cross_attention_dim: int = 2048,
+ audio_scale_factor: int = 4,
+ audio_pos_embed_max_pos: int = 20,
+ audio_sampling_rate: int = 16000,
+ audio_hop_length: int = 160,
+ num_layers: int = 48, # Shared arguments
+ activation_fn: str = "gelu-approximate",
+ qk_norm: str = "rms_norm_across_heads",
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ caption_channels: int = 3840,
+ attention_bias: bool = True,
+ attention_out_bias: bool = True,
+ rope_theta: float = 10000.0,
+ rope_double_precision: bool = True,
+ causal_offset: int = 1,
+ timestep_scale_multiplier: int = 1000,
+ cross_attn_timestep_scale_multiplier: int = 1000,
+ rope_type: str = "interleaved",
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ audio_out_channels = audio_out_channels or audio_in_channels
+ inner_dim = num_attention_heads * attention_head_dim
+ audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
+
+ # 1. Patchification input projections
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)
+
+ # 2. Prompt embeddings
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+ self.audio_caption_projection = PixArtAlphaTextProjection(
+ in_features=caption_channels, hidden_size=audio_inner_dim
+ )
+
+ # 3. Timestep Modulation Params and Embedding
+ # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding
+ # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters
+ self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False)
+ self.audio_time_embed = LTX2AdaLayerNormSingle(
+ audio_inner_dim, num_mod_params=6, use_additional_conditions=False
+ )
+
+ # 3.2. Global Cross Attention Modulation Parameters
+ # Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params,
+ # which are then further modified by per-block modulation params in each transformer block.
+ # There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and
+ # video-to-audio (v2a) cross attention
+ self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle(
+ inner_dim, num_mod_params=4, use_additional_conditions=False
+ )
+ self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle(
+ audio_inner_dim, num_mod_params=4, use_additional_conditions=False
+ )
+ # Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys
+ # and values (KV))
+ self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle(
+ inner_dim, num_mod_params=1, use_additional_conditions=False
+ )
+ # Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys
+ # and values (KV))
+ self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle(
+ audio_inner_dim, num_mod_params=1, use_additional_conditions=False
+ )
+
+ # 3.3. Output Layer Scale/Shift Modulation parameters
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5)
+
+ # 4. Rotary Positional Embeddings (RoPE)
+ # Self-Attention
+ self.rope = LTX2AudioVideoRotaryPosEmbed(
+ dim=inner_dim,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ base_num_frames=pos_embed_max_pos,
+ base_height=base_height,
+ base_width=base_width,
+ scale_factors=vae_scale_factors,
+ theta=rope_theta,
+ causal_offset=causal_offset,
+ modality="video",
+ double_precision=rope_double_precision,
+ rope_type=rope_type,
+ num_attention_heads=num_attention_heads,
+ )
+ self.audio_rope = LTX2AudioVideoRotaryPosEmbed(
+ dim=audio_inner_dim,
+ patch_size=audio_patch_size,
+ patch_size_t=audio_patch_size_t,
+ base_num_frames=audio_pos_embed_max_pos,
+ sampling_rate=audio_sampling_rate,
+ hop_length=audio_hop_length,
+ scale_factors=[audio_scale_factor],
+ theta=rope_theta,
+ causal_offset=causal_offset,
+ modality="audio",
+ double_precision=rope_double_precision,
+ rope_type=rope_type,
+ num_attention_heads=audio_num_attention_heads,
+ )
+
+ # Audio-to-Video, Video-to-Audio Cross-Attention
+ cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos)
+ self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed(
+ dim=audio_cross_attention_dim,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ base_num_frames=cross_attn_pos_embed_max_pos,
+ base_height=base_height,
+ base_width=base_width,
+ theta=rope_theta,
+ causal_offset=causal_offset,
+ modality="video",
+ double_precision=rope_double_precision,
+ rope_type=rope_type,
+ num_attention_heads=num_attention_heads,
+ )
+ self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed(
+ dim=audio_cross_attention_dim,
+ patch_size=audio_patch_size,
+ patch_size_t=audio_patch_size_t,
+ base_num_frames=cross_attn_pos_embed_max_pos,
+ sampling_rate=audio_sampling_rate,
+ hop_length=audio_hop_length,
+ theta=rope_theta,
+ causal_offset=causal_offset,
+ modality="audio",
+ double_precision=rope_double_precision,
+ rope_type=rope_type,
+ num_attention_heads=audio_num_attention_heads,
+ )
+
+ # 5. Transformer Blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ LTX2VideoTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ audio_dim=audio_inner_dim,
+ audio_num_attention_heads=audio_num_attention_heads,
+ audio_attention_head_dim=audio_attention_head_dim,
+ audio_cross_attention_dim=audio_cross_attention_dim,
+ qk_norm=qk_norm,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ attention_out_bias=attention_out_bias,
+ eps=norm_eps,
+ elementwise_affine=norm_elementwise_affine,
+ rope_type=rope_type,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 6. Output layers
+ self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels)
+
+ self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False)
+ self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ audio_hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ audio_encoder_hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ audio_timestep: Optional[torch.LongTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ audio_encoder_attention_mask: Optional[torch.Tensor] = None,
+ num_frames: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ fps: float = 24.0,
+ audio_num_frames: Optional[int] = None,
+ video_coords: Optional[torch.Tensor] = None,
+ audio_coords: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> torch.Tensor:
+ """
+ Forward pass for LTX-2.0 audiovisual video transformer.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`.
+ audio_hidden_states (`torch.Tensor`):
+ Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
+ audio_encoder_hidden_states (`torch.Tensor`):
+ Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
+ timestep (`torch.Tensor`):
+ Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by
+ `self.config.timestep_scale_multiplier`.
+ audio_timestep (`torch.Tensor`, *optional*):
+ Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
+ params. This is only used by certain pipelines such as the I2V pipeline.
+ encoder_attention_mask (`torch.Tensor`, *optional*):
+ Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
+ audio_encoder_attention_mask (`torch.Tensor`, *optional*):
+ Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
+ num_frames (`int`, *optional*):
+ The number of latent video frames. Used if calculating the video coordinates for RoPE.
+ height (`int`, *optional*):
+ The latent video height. Used if calculating the video coordinates for RoPE.
+ width (`int`, *optional*):
+ The latent video width. Used if calculating the video coordinates for RoPE.
+ fps (`float`, *optional*, defaults to `24.0`):
+ The desired frames per second of the generated video. Used if calculating the video coordinates for
+ RoPE.
+ audio_num_frames (`int`, *optional*):
+ The number of latent audio frames. Used if calculating the audio coordinates for RoPE.
+ video_coords (`torch.Tensor`, *optional*):
+ The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
+ `(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
+ audio_coords (`torch.Tensor`, *optional*):
+ The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
+ `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
+ attention_kwargs (`Dict[str, Any]`, *optional*):
+ Optional dict of keyword args to be passed to the attention processor.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple.
+
+ Returns:
+ `AudioVisualModelOutput` or `tuple`:
+ If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a
+ `tuple` is returned where the first element is the denoised video latent patch sequence and the second
+ element is the denoised audio latent patch sequence.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # Determine timestep for audio.
+ audio_timestep = audio_timestep if audio_timestep is not None else timestep
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
+ audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
+ audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
+
+ batch_size = hidden_states.size(0)
+
+ # 1. Prepare RoPE positional embeddings
+ if video_coords is None:
+ video_coords = self.rope.prepare_video_coords(
+ batch_size, num_frames, height, width, hidden_states.device, fps=fps
+ )
+ if audio_coords is None:
+ audio_coords = self.audio_rope.prepare_audio_coords(
+ batch_size, audio_num_frames, audio_hidden_states.device
+ )
+
+ video_rotary_emb = self.rope(video_coords, device=hidden_states.device)
+ audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
+
+ video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
+ audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
+ audio_coords[:, 0:1, :], device=audio_hidden_states.device
+ )
+
+ # 2. Patchify input projections
+ hidden_states = self.proj_in(hidden_states)
+ audio_hidden_states = self.audio_proj_in(audio_hidden_states)
+
+ # 3. Prepare timestep embeddings and modulation parameters
+ timestep_cross_attn_gate_scale_factor = (
+ self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
+ )
+
+ # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
+ # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
+ # modulation with scale_shift_table (and similarly for audio)
+ temb, embedded_timestep = self.time_embed(
+ timestep.flatten(),
+ batch_size=batch_size,
+ hidden_dtype=hidden_states.dtype,
+ )
+ temb = temb.view(batch_size, -1, temb.size(-1))
+ embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+
+ temb_audio, audio_embedded_timestep = self.audio_time_embed(
+ audio_timestep.flatten(),
+ batch_size=batch_size,
+ hidden_dtype=audio_hidden_states.dtype,
+ )
+ temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
+ audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
+
+ # 3.2. Prepare global modality cross attention modulation parameters
+ video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
+ timestep.flatten(),
+ batch_size=batch_size,
+ hidden_dtype=hidden_states.dtype,
+ )
+ video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
+ timestep.flatten() * timestep_cross_attn_gate_scale_factor,
+ batch_size=batch_size,
+ hidden_dtype=hidden_states.dtype,
+ )
+ video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
+ batch_size, -1, video_cross_attn_scale_shift.shape[-1]
+ )
+ video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
+
+ audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
+ audio_timestep.flatten(),
+ batch_size=batch_size,
+ hidden_dtype=audio_hidden_states.dtype,
+ )
+ audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
+ audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
+ batch_size=batch_size,
+ hidden_dtype=audio_hidden_states.dtype,
+ )
+ audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
+ batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
+ )
+ audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
+
+ # 4. Prepare prompt embeddings
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
+
+ audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
+ audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
+
+ # 5. Run transformer blocks
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ audio_hidden_states,
+ encoder_hidden_states,
+ audio_encoder_hidden_states,
+ temb,
+ temb_audio,
+ video_cross_attn_scale_shift,
+ audio_cross_attn_scale_shift,
+ video_cross_attn_a2v_gate,
+ audio_cross_attn_v2a_gate,
+ video_rotary_emb,
+ audio_rotary_emb,
+ video_cross_attn_rotary_emb,
+ audio_cross_attn_rotary_emb,
+ encoder_attention_mask,
+ audio_encoder_attention_mask,
+ )
+ else:
+ hidden_states, audio_hidden_states = block(
+ hidden_states=hidden_states,
+ audio_hidden_states=audio_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
+ temb=temb,
+ temb_audio=temb_audio,
+ temb_ca_scale_shift=video_cross_attn_scale_shift,
+ temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
+ temb_ca_gate=video_cross_attn_a2v_gate,
+ temb_ca_audio_gate=audio_cross_attn_v2a_gate,
+ video_rotary_emb=video_rotary_emb,
+ audio_rotary_emb=audio_rotary_emb,
+ ca_video_rotary_emb=video_cross_attn_rotary_emb,
+ ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
+ encoder_attention_mask=encoder_attention_mask,
+ audio_encoder_attention_mask=audio_encoder_attention_mask,
+ )
+
+ # 6. Output layers (including unpatchification)
+ scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+ shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = hidden_states * (1 + scale) + shift
+ output = self.proj_out(hidden_states)
+
+ audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None]
+ audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1]
+
+ audio_hidden_states = self.audio_norm_out(audio_hidden_states)
+ audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
+ audio_output = self.audio_proj_out(audio_hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output, audio_output)
+ return AudioVisualModelOutput(sample=output, audio_sample=audio_output)
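The two trickiest parts of the forward pass above are easy to sanity-check in isolation: the 2D text mask is turned into an additive attention bias, and the output layer applies `x * (1 + scale) + shift` with shift/scale taken from `scale_shift_table` plus the per-token timestep embedding. A minimal sketch with toy sizes (the dimensions and random tables below are illustrative, not the model's real configuration):

```python
import torch
import torch.nn.functional as F

# Text mask -> additive bias: 1.0 marks valid tokens, 0.0 marks padding;
# padding positions become a large negative bias added to the attention logits.
mask = torch.tensor([[1.0, 1.0, 0.0]])          # (batch_size=1, text_seq_len=3)
bias = (1 - mask) * -10000.0                    # ~ [[0., 0., -10000.]]
bias = bias.unsqueeze(1)                        # (batch_size, 1, text_seq_len), broadcastable over queries

# Output modulation: non-affine LayerNorm, then x * (1 + scale) + shift.
inner_dim = 8                                   # toy size
hidden_states = torch.randn(1, 4, inner_dim)    # (batch_size, num_tokens, inner_dim)
scale_shift_table = torch.randn(2, inner_dim)   # row 0 -> shift, row 1 -> scale
embedded_timestep = torch.randn(1, 4, inner_dim)

values = scale_shift_table[None, None] + embedded_timestep[:, :, None]  # (1, 4, 2, inner_dim)
shift, scale = values[:, :, 0], values[:, :, 1]
out = F.layer_norm(hidden_states, (inner_dim,)) * (1 + scale) + shift
print(bias.shape, out.shape)                    # torch.Size([1, 1, 3]) torch.Size([1, 4, 8])
```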
diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py
index 0a09aa720b3f..139ceaefa4e9 100644
--- a/src/diffusers/models/transformers/transformer_ovis_image.py
+++ b/src/diffusers/models/transformers/transformer_ovis_image.py
@@ -21,7 +21,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -530,11 +530,7 @@ def forward(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
- if is_torch_npu_available():
- freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
- image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
- else:
- image_rotary_emb = self.pos_embed(ids)
+ image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 1229bab169b2..cf11d8e01fb4 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -24,7 +24,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
@@ -142,6 +142,32 @@ def apply_rotary_emb_qwen(
return x_out.type_as(x)
+def compute_text_seq_len_from_mask(
+ encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
+) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Compute the text sequence length without assuming contiguous masks. Returns the padded sequence length to use for
+ RoPE, the per-sample active token lengths, and the mask normalized to `torch.bool`.
+ """
+ batch_size, text_seq_len = encoder_hidden_states.shape[:2]
+ if encoder_hidden_states_mask is None:
+ return text_seq_len, None, None
+
+ if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
+ raise ValueError(
+ f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
+ f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
+ )
+
+ if encoder_hidden_states_mask.dtype != torch.bool:
+ encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
+
+ position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
+ active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
+ has_active = encoder_hidden_states_mask.any(dim=1)
+ per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+ return text_seq_len, per_sample_len, encoder_hidden_states_mask
+
+
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, use_additional_t_cond=False):
super().__init__()
@@ -207,21 +233,50 @@ def rope_params(self, index, dim, theta=10000):
def forward(
self,
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
- txt_seq_lens: List[int],
- device: torch.device,
+ txt_seq_lens: Optional[List[int]] = None,
+ device: torch.device = None,
+ max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
A list of 3 integers [frame, height, width] representing the shape of the video.
- txt_seq_lens (`List[int]`):
- A list of integers of length batch_size representing the length of each text prompt.
- device: (`torch.device`):
+ txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
+ Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
+ device: (`torch.device`, *optional*):
The device on which to perform the RoPE computation.
+ max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
+ The maximum text sequence length for RoPE computation. This should match the encoder hidden states
+ sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
"""
- if self.pos_freqs.device != device:
- self.pos_freqs = self.pos_freqs.to(device)
- self.neg_freqs = self.neg_freqs.to(device)
+ # Handle deprecated txt_seq_lens parameter
+ if txt_seq_lens is not None:
+ deprecate(
+ "txt_seq_lens",
+ "0.39.0",
+ "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
+ "Please use `max_txt_seq_len` instead. "
+ "The new parameter accepts a single int or tensor value representing the maximum text sequence length.",
+ standard_warn=False,
+ )
+ if max_txt_seq_len is None:
+ # Use max of txt_seq_lens for backward compatibility
+ max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens
+
+ if max_txt_seq_len is None:
+ raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")
+
+ # Validate batch inference with variable-sized images
+ if isinstance(video_fhw, list) and len(video_fhw) > 1:
+ # Check if all instances have the same size
+ first_fhw = video_fhw[0]
+ if not all(fhw == first_fhw for fhw in video_fhw):
+ logger.warning(
+ "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. "
+ "All images in the batch should have the same dimensions (frame, height, width). "
+ f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
+ "for RoPE computation, which may lead to incorrect results for other images in the batch."
+ )
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
@@ -233,8 +288,7 @@ def forward(
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
- video_freq = self._compute_video_freqs(frame, height, width, idx)
- video_freq = video_freq.to(device)
+ video_freq = self._compute_video_freqs(frame, height, width, idx, device)
vid_freqs.append(video_freq)
if self.scale_rope:
@@ -242,17 +296,23 @@ def forward(
else:
max_vid_index = max(height, width, max_vid_index)
- max_len = max(txt_seq_lens)
- txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+ max_txt_seq_len_int = int(max_txt_seq_len)
+ # Create device-specific copy for text freqs without modifying self.pos_freqs
+ txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=128)
- def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
+ def _compute_video_freqs(
+ self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
+ ) -> torch.Tensor:
seq_lens = frame * height * width
- freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
- freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+ neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+ freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -304,14 +364,35 @@ def rope_params(self, index, dim, theta=10000):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
- def forward(self, video_fhw, txt_seq_lens, device):
+ def forward(
+ self,
+ video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
+ max_txt_seq_len: Union[int, torch.Tensor],
+ device: torch.device = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
- Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
- txt_length: [bs] a list of 1 integers representing the length of the text
+ Args:
+ video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
+ A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer
+ structures.
+ max_txt_seq_len (`int` or `torch.Tensor`):
+ The maximum text sequence length for RoPE computation. This should match the encoder hidden states
+ sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
+ device: (`torch.device`, *optional*):
+ The device on which to perform the RoPE computation.
"""
- if self.pos_freqs.device != device:
- self.pos_freqs = self.pos_freqs.to(device)
- self.neg_freqs = self.neg_freqs.to(device)
+ # Validate batch inference with variable-sized images
+ # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
+ if isinstance(video_fhw, list) and len(video_fhw) > 1:
+ # Check if this is batch inference (list of layer lists/tuples)
+ first_entry = video_fhw[0]
+ if not all(entry == first_entry for entry in video_fhw):
+ logger.warning(
+ "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. "
+ "All images in the batch should have the same layer structure. "
+ f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} "
+ "for RoPE computation, which may lead to incorrect results for other images in the batch."
+ )
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
@@ -324,11 +405,10 @@ def forward(self, video_fhw, txt_seq_lens, device):
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
- video_freq = self._compute_video_freqs(frame, height, width, idx)
+ video_freq = self._compute_video_freqs(frame, height, width, idx, device)
else:
### For the condition image, we set the layer index to -1
- video_freq = self._compute_condition_freqs(frame, height, width)
- video_freq = video_freq.to(device)
+ video_freq = self._compute_condition_freqs(frame, height, width, device)
vid_freqs.append(video_freq)
if self.scale_rope:
@@ -337,17 +417,21 @@ def forward(self, video_fhw, txt_seq_lens, device):
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
- max_len = max(txt_seq_lens)
- txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+ max_txt_seq_len_int = int(max_txt_seq_len)
+ # Create device-specific copy for text freqs without modifying self.pos_freqs
+ txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
- def _compute_video_freqs(self, frame, height, width, idx=0):
+ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
- freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
- freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+ neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+ freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -363,10 +447,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
- def _compute_condition_freqs(self, frame, height, width):
+ def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
- freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
- freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+ neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+ freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -454,7 +541,6 @@ def __call__(
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
- # Compute joint attention
joint_hidden_states = dispatch_attention_fn(
joint_query,
joint_key,
@@ -675,11 +761,14 @@ class QwenImageTransformer2DModel(
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]
+ # Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
_cp_plan = {
- "": {
+ "transformer_blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
- "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "transformer_blocks.*": {
+ "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"pos_embed": {
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
@@ -762,14 +851,25 @@ def forward(
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
- Mask of the input conditions.
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+ Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+ Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+ (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
+ img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+ Image shapes for RoPE computation.
+ txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
+ Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be
+ used to compute RoPE sequence length.
+ guidance (`torch.Tensor`, *optional*):
+ Guidance tensor for conditional generation.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_block_samples (*optional*):
+ ControlNet block samples to add to the transformer blocks.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
@@ -778,6 +878,15 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
+ if txt_seq_lens is not None:
+ deprecate(
+ "txt_seq_lens",
+ "0.39.0",
+ "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
+ "Please use `encoder_hidden_states_mask` instead. "
+ "The mask-based approach is more flexible and supports variable-length sequences.",
+ standard_warn=False,
+ )
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -810,6 +919,11 @@ def forward(
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
+ # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+ text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+ encoder_hidden_states, encoder_hidden_states_mask
+ )
+
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
@@ -819,7 +933,17 @@ def forward(
else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
)
- image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+ image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
+
+ # Construct joint attention mask once to avoid reconstructing in every block
+ # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
+ block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
+ if encoder_hidden_states_mask is not None:
+ # Build joint mask: [text_mask, all_ones_for_image]
+ batch_size, image_seq_len = hidden_states.shape[:2]
+ image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
+ joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+ block_attention_kwargs["attention_mask"] = joint_attention_mask
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -827,10 +951,10 @@ def forward(
block,
hidden_states,
encoder_hidden_states,
- encoder_hidden_states_mask,
+ None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
temb,
image_rotary_emb,
- attention_kwargs,
+ block_attention_kwargs,
modulate_index,
)
@@ -838,10 +962,10 @@ def forward(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
- encoder_hidden_states_mask=encoder_hidden_states_mask,
+ encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
temb=temb,
image_rotary_emb=image_rotary_emb,
- joint_attention_kwargs=attention_kwargs,
+ joint_attention_kwargs=block_attention_kwargs,
modulate_index=modulate_index,
)
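The mask-based RoPE length and the single joint attention mask introduced above can be illustrated with toy sizes; the sketch below reproduces the helper's behavior inline (for masks with at least one valid token per sample) instead of importing it:

```python
import torch

# Non-contiguous masks are allowed: the full padded length is still used for RoPE,
# while the bool mask decides which text tokens attention may attend to.
encoder_hidden_states = torch.randn(2, 5, 16)        # (batch_size, text_seq_len, dim)
encoder_hidden_states_mask = torch.tensor(
    [[1, 1, 1, 0, 0],
     [1, 0, 1, 1, 0]]
)

text_seq_len = encoder_hidden_states.shape[1]        # 5, used as max_txt_seq_len for RoPE
bool_mask = encoder_hidden_states_mask.to(torch.bool)

# Per-sample active length (last valid index + 1).
position_ids = torch.arange(text_seq_len)
active_positions = torch.where(bool_mask, position_ids, position_ids.new_zeros(()))
per_sample_len = active_positions.max(dim=1).values + 1   # tensor([3, 4])

# Joint mask built once for all blocks: text mask followed by an all-True image mask.
image_seq_len = 7                                     # toy number of image tokens
image_mask = torch.ones((2, image_seq_len), dtype=torch.bool)
joint_attention_mask = torch.cat([bool_mask, image_mask], dim=1)
print(text_seq_len, per_sample_len.tolist(), joint_attention_mask.shape)
# 5 [3, 4] torch.Size([2, 12])
```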
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index f7693ec5d3ac..132f615f2199 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -134,7 +134,8 @@ def apply_rotary_emb(
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
- parallel_config=self._parallel_config,
+ # Reference: https://github.com/huggingface/diffusers/pull/12909
+ parallel_config=None,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -147,7 +148,8 @@ def apply_rotary_emb(
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
- parallel_config=self._parallel_config,
+ # Reference: https://github.com/huggingface/diffusers/pull/12909
+ parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
@@ -552,9 +554,11 @@ class WanTransformer3DModel(
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
- "blocks.*": {
- "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
- },
+ # Reference: https://github.com/huggingface/diffusers/pull/12909
+ # We need to disable the splitting of encoder_hidden_states because the image_encoder
+ # (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the token
+ # count of encoder_hidden_states, which is always 769 (512 + 257) after concatenation,
+ # to be indivisible by the number of devices in the CP.
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
"": {
"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py
index 6a47a67385a3..8860f4bca91f 100644
--- a/src/diffusers/models/transformers/transformer_wan_animate.py
+++ b/src/diffusers/models/transformers/transformer_wan_animate.py
@@ -609,7 +609,8 @@ def apply_rotary_emb(
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
- parallel_config=self._parallel_config,
+ # Reference: https://github.com/huggingface/diffusers/pull/12909
+ parallel_config=None,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -622,7 +623,8 @@ def apply_rotary_emb(
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
- parallel_config=self._parallel_config,
+ # Reference: https://github.com/huggingface/diffusers/pull/12909
+ parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py
index 5fcc1a176d1b..e64db23f3831 100644
--- a/src/diffusers/modular_pipelines/__init__.py
+++ b/src/diffusers/modular_pipelines/__init__.py
@@ -63,6 +63,8 @@
"QwenImageEditAutoBlocks",
"QwenImageEditPlusModularPipeline",
"QwenImageEditPlusAutoBlocks",
+ "QwenImageLayeredModularPipeline",
+ "QwenImageLayeredAutoBlocks",
]
_import_structure["z_image"] = [
"ZImageAutoBlocks",
@@ -96,6 +98,8 @@
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
+ QwenImageLayeredAutoBlocks,
+ QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py
index cb7e8fb73697..e16abb382313 100644
--- a/src/diffusers/modular_pipelines/components_manager.py
+++ b/src/diffusers/modular_pipelines/components_manager.py
@@ -160,7 +160,10 @@ def __call__(self, hooks, model_id, model, execution_device):
if len(hooks) == 0:
return []
- current_module_size = model.get_memory_footprint()
+ try:
+ current_module_size = model.get_memory_footprint()
+ except AttributeError:
+ raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}.")
device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
@@ -703,7 +706,20 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
if not is_accelerate_available():
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
- # TODO: add a warning if mem_get_info isn't available on `device`.
+ if device is None:
+ device = get_device()
+ if not isinstance(device, torch.device):
+ device = torch.device(device)
+
+ device_type = device.type
+ device_module = getattr(torch, device_type, torch.cuda)
+ if not hasattr(device_module, "mem_get_info"):
+ raise NotImplementedError(
+ f"`enable_auto_cpu_offload() relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
+ )
+
+ if device.index is None:
+ device = torch.device(f"{device.type}:{0}")
for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
@@ -711,11 +727,7 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
- if device is None:
- device = get_device()
- device = torch.device(device)
- if device.index is None:
- device = torch.device(f"{device.type}:{0}")
+
all_hooks = []
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):
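The device validation added to `enable_auto_cpu_offload` above boils down to three steps: resolve the device, require a backend module that exposes `mem_get_info()`, and pin an explicit device index. A standalone sketch of the same pattern (the `"cuda"` device string is only an example):

```python
import torch

device = torch.device("cuda")
device_module = getattr(torch, device.type, torch.cuda)
# The offload strategy needs mem_get_info() to know how much device memory is free.
if not hasattr(device_module, "mem_get_info"):
    raise NotImplementedError(f"`mem_get_info()` is not implemented for {device.type}.")
# Default to index 0 when no explicit index was given.
if device.index is None:
    device = torch.device(f"{device.type}:0")
print(device)   # cuda:0
```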
diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py
index 8309eebfeb37..45b1c6bc136f 100644
--- a/src/diffusers/modular_pipelines/flux/inputs.py
+++ b/src/diffusers/modular_pipelines/flux/inputs.py
@@ -121,7 +121,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
return components, state
-# Adapted from `QwenImageInputsDynamicStep`
+# Adapted from `QwenImageAdditionalInputsStep`
class FluxInputsDynamicStep(ModularPipelineBlocks):
model_name = "flux"
diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py
index 4f142a453f3b..f848afe9a3ae 100644
--- a/src/diffusers/modular_pipelines/mellon_node_utils.py
+++ b/src/diffusers/modular_pipelines/mellon_node_utils.py
@@ -4,7 +4,7 @@
# Simple typed wrapper for parameter overrides
from dataclasses import asdict, dataclass
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from huggingface_hub.utils import (
@@ -23,10 +23,18 @@
@dataclass(frozen=True)
class MellonParam:
"""
- Parameter definition for Mellon nodes.
-
- Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with MellonParam(name="...",
- label="...", type="...").
+ Parameter definition for Mellon nodes.
+
+ Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with
+ MellonParam(name="...", label="...", type="...").
+
+ Example:
+ ```python
+ # Custom param
+ MellonParam(name="my_param", label="My Param", type="float", default=0.5)
+ # Output in Mellon node definition:
+ # "my_param": {"label": "My Param", "type": "float", "default": 0.5}
+ ```
"""
name: str
@@ -42,55 +50,165 @@ class MellonParam:
fieldOptions: Optional[Dict[str, Any]] = None
onChange: Any = None
onSignal: Any = None
+ required_block_params: Optional[Union[str, List[str]]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict for Mellon schema, excluding None values and name."""
data = asdict(self)
- return {k: v for k, v in data.items() if v is not None and k != "name"}
+ return {k: v for k, v in data.items() if v is not None and k not in ("name", "required_block_params")}
@classmethod
def image(cls) -> "MellonParam":
- return cls(name="image", label="Image", type="image", display="input")
+ """
+ Image input parameter.
+
+ Mellon node definition:
+ "image": {"label": "Image", "type": "image", "display": "input"}
+ """
+ return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"])
@classmethod
def images(cls) -> "MellonParam":
- return cls(name="images", label="Images", type="image", display="output")
+ """
+ Images output parameter.
+
+ Mellon node definition:
+ "images": {"label": "Images", "type": "image", "display": "output"}
+ """
+ return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"])
@classmethod
def control_image(cls, display: str = "input") -> "MellonParam":
- return cls(name="control_image", label="Control Image", type="image", display=display)
+ """
+ Control image parameter for ControlNet.
+
+ Mellon node definition (display="input"):
+ "control_image": {"label": "Control Image", "type": "image", "display": "input"}
+ """
+ return cls(
+ name="control_image",
+ label="Control Image",
+ type="image",
+ display=display,
+ required_block_params=["control_image"],
+ )
@classmethod
def latents(cls, display: str = "input") -> "MellonParam":
- return cls(name="latents", label="Latents", type="latents", display=display)
+ """
+ Latents parameter.
+
+ Mellon node definition (display="input"):
+ "latents": {"label": "Latents", "type": "latents", "display": "input"}
+
+ Mellon node definition (display="output"):
+ "latents": {"label": "Latents", "type": "latents", "display": "output"}
+ """
+ return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"])
@classmethod
def image_latents(cls, display: str = "input") -> "MellonParam":
- return cls(name="image_latents", label="Image Latents", type="latents", display=display)
+ """
+ Image latents parameter for img2img workflows.
+
+ Mellon node definition (display="input"):
+ "image_latents": {"label": "Image Latents", "type": "latents", "display": "input"}
+ """
+ return cls(
+ name="image_latents",
+ label="Image Latents",
+ type="latents",
+ display=display,
+ required_block_params=["image_latents"],
+ )
+
+ @classmethod
+ def first_frame_latents(cls, display: str = "input") -> "MellonParam":
+ """
+ First frame latents for video generation.
+
+ Mellon node definition (display="input"):
+ "first_frame_latents": {"label": "First Frame Latents", "type": "latents", "display": "input"}
+ """
+ return cls(
+ name="first_frame_latents",
+ label="First Frame Latents",
+ type="latents",
+ display=display,
+ required_block_params=["first_frame_latents"],
+ )
@classmethod
def image_latents_with_strength(cls) -> "MellonParam":
+ """
+ Image latents with strength-based onChange behavior. When connected, shows strength slider; when disconnected,
+ shows height/width.
+
+ Mellon node definition:
+ "image_latents": {
+ "label": "Image Latents", "type": "latents", "display": "input", "onChange": {"false": ["height",
+ "width"], "true": ["strength"]}
+ }
+ """
return cls(
name="image_latents",
label="Image Latents",
type="latents",
display="input",
onChange={"false": ["height", "width"], "true": ["strength"]},
+ required_block_params=["image_latents", "strength"],
)
@classmethod
def latents_preview(cls) -> "MellonParam":
"""
- `Latents Preview` is a special output parameter that is used to preview the latents in the UI.
+ Latents preview output for visualizing latents in the UI.
+
+ Mellon node definition:
+ "latents_preview": {"label": "Latents Preview", "type": "latent", "display": "output"}
"""
return cls(name="latents_preview", label="Latents Preview", type="latent", display="output")
@classmethod
def embeddings(cls, display: str = "output") -> "MellonParam":
+ """
+ Text embeddings parameter.
+
+ Mellon node definition (display="output"):
+ "embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "output"}
+
+ Mellon node definition (display="input"):
+ "embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "input"}
+ """
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
+ @classmethod
+ def image_embeds(cls, display: str = "output") -> "MellonParam":
+ """
+ Image embeddings parameter for IP-Adapter workflows.
+
+ Mellon node definition (display="output"):
+ "image_embeds": {"label": "Image Embeddings", "type": "image_embeds", "display": "output"}
+ """
+ return cls(
+ name="image_embeds",
+ label="Image Embeddings",
+ type="image_embeds",
+ display=display,
+ required_block_params=["image_embeds"],
+ )
+
@classmethod
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
+ """
+ ControlNet conditioning scale slider.
+
+ Mellon node definition (default=0.5):
+ "controlnet_conditioning_scale": {
+ "label": "Controlnet Conditioning Scale", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0,
+ "step": 0.01
+ }
+ """
return cls(
name="controlnet_conditioning_scale",
label="Controlnet Conditioning Scale",
@@ -99,10 +217,20 @@ def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
min=0.0,
max=1.0,
step=0.01,
+ required_block_params=["controlnet_conditioning_scale"],
)
@classmethod
def control_guidance_start(cls, default: float = 0.0) -> "MellonParam":
+ """
+ Control guidance start timestep.
+
+ Mellon node definition (default=0.0):
+ "control_guidance_start": {
+ "label": "Control Guidance Start", "type": "float", "default": 0.0, "min": 0.0, "max": 1.0, "step":
+ 0.01
+ }
+ """
return cls(
name="control_guidance_start",
label="Control Guidance Start",
@@ -111,10 +239,19 @@ def control_guidance_start(cls, default: float = 0.0) -> "MellonParam":
min=0.0,
max=1.0,
step=0.01,
+ required_block_params=["control_guidance_start"],
)
@classmethod
def control_guidance_end(cls, default: float = 1.0) -> "MellonParam":
+ """
+ Control guidance end timestep.
+
+ Mellon node definition (default=1.0):
+ "control_guidance_end": {
+ "label": "Control Guidance End", "type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01
+ }
+ """
return cls(
name="control_guidance_end",
label="Control Guidance End",
@@ -123,22 +260,73 @@ def control_guidance_end(cls, default: float = 1.0) -> "MellonParam":
min=0.0,
max=1.0,
step=0.01,
+ required_block_params=["control_guidance_end"],
)
@classmethod
def prompt(cls, default: str = "") -> "MellonParam":
- return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea")
+ """
+ Text prompt input as textarea.
+
+ Mellon node definition (default=""):
+ "prompt": {"label": "Prompt", "type": "string", "default": "", "display": "textarea"}
+ """
+ return cls(
+ name="prompt",
+ label="Prompt",
+ type="string",
+ default=default,
+ display="textarea",
+ required_block_params=["prompt"],
+ )
@classmethod
def negative_prompt(cls, default: str = "") -> "MellonParam":
- return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea")
+ """
+ Negative prompt input as textarea.
+
+ Mellon node definition (default=""):
+ "negative_prompt": {"label": "Negative Prompt", "type": "string", "default": "", "display": "textarea"}
+ """
+ return cls(
+ name="negative_prompt",
+ label="Negative Prompt",
+ type="string",
+ default=default,
+ display="textarea",
+ required_block_params=["negative_prompt"],
+ )
@classmethod
def strength(cls, default: float = 0.5) -> "MellonParam":
- return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01)
+ """
+ Denoising strength for img2img.
+
+ Mellon node definition (default=0.5):
+ "strength": {"label": "Strength", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}
+ """
+ return cls(
+ name="strength",
+ label="Strength",
+ type="float",
+ default=default,
+ min=0.0,
+ max=1.0,
+ step=0.01,
+ required_block_params=["strength"],
+ )
@classmethod
def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
+ """
+ CFG guidance scale slider.
+
+ Mellon node definition (default=5.0):
+ "guidance_scale": {
+ "label": "Guidance Scale", "type": "float", "display": "slider", "default": 5.0, "min": 1.0, "max":
+ 30.0, "step": 0.1
+ }
+ """
return cls(
name="guidance_scale",
label="Guidance Scale",
@@ -152,95 +340,273 @@ def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
@classmethod
def height(cls, default: int = 1024) -> "MellonParam":
- return cls(name="height", label="Height", type="int", default=default, min=64, step=8)
+ """
+ Image height in pixels.
+
+ Mellon node definition (default=1024):
+ "height": {"label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8}
+ """
+ return cls(
+ name="height",
+ label="Height",
+ type="int",
+ default=default,
+ min=64,
+ step=8,
+ required_block_params=["height"],
+ )
@classmethod
def width(cls, default: int = 1024) -> "MellonParam":
- return cls(name="width", label="Width", type="int", default=default, min=64, step=8)
+ """
+ Image width in pixels.
+
+ Mellon node definition (default=1024):
+ "width": {"label": "Width", "type": "int", "default": 1024, "min": 64, "step": 8}
+ """
+ return cls(
+ name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]
+ )
@classmethod
def seed(cls, default: int = 0) -> "MellonParam":
- return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random")
+ """
+ Random seed with randomize button.
+
+ Mellon node definition (default=0):
+ "seed": {
+ "label": "Seed", "type": "int", "default": 0, "min": 0, "max": 4294967295, "display": "random"
+ }
+ """
+ return cls(
+ name="seed",
+ label="Seed",
+ type="int",
+ default=default,
+ min=0,
+ max=4294967295,
+ display="random",
+ required_block_params=["generator"],
+ )
@classmethod
def num_inference_steps(cls, default: int = 25) -> "MellonParam":
+ """
+ Number of denoising steps slider.
+
+ Mellon node definition (default=25):
+ "num_inference_steps": {
+ "label": "Steps", "type": "int", "default": 25, "min": 1, "max": 100, "display": "slider"
+ }
+ """
return cls(
- name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
+ name="num_inference_steps",
+ label="Steps",
+ type="int",
+ default=default,
+ min=1,
+ max=100,
+ display="slider",
+ required_block_params=["num_inference_steps"],
)
+ @classmethod
+ def num_frames(cls, default: int = 81) -> "MellonParam":
+ """
+ Number of video frames slider.
+
+ Mellon node definition (default=81):
+ "num_frames": {"label": "Frames", "type": "int", "default": 81, "min": 1, "max": 480, "display": "slider"}
+ """
+ return cls(
+ name="num_frames",
+ label="Frames",
+ type="int",
+ default=default,
+ min=1,
+ max=480,
+ display="slider",
+ required_block_params=["num_frames"],
+ )
+
+ @classmethod
+ def layers(cls, default: int = 4) -> "MellonParam":
+ """
+ Number of layers slider (for layered diffusion).
+
+ Mellon node definition (default=4):
+ "layers": {"label": "Layers", "type": "int", "default": 4, "min": 1, "max": 10, "display": "slider"}
+ """
+ return cls(
+ name="layers",
+ label="Layers",
+ type="int",
+ default=default,
+ min=1,
+ max=10,
+ display="slider",
+ required_block_params=["layers"],
+ )
+
+ @classmethod
+ def videos(cls) -> "MellonParam":
+ """
+ Video output parameter.
+
+ Mellon node definition:
+ "videos": {"label": "Videos", "type": "video", "display": "output"}
+ """
+ return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"])
+
@classmethod
def vae(cls) -> "MellonParam":
"""
- VAE model info dict.
+ VAE model input.
- Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
- the actual model.
+ Mellon node definition:
+ "vae": {"label": "VAE", "type": "diffusers_auto_model", "display": "input"}
+
+ Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
+ components.get_one(model_id) to retrieve the actual model.
+ """
+ return cls(
+ name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]
+ )
+
+ @classmethod
+ def image_encoder(cls) -> "MellonParam":
"""
- return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
+ Image encoder model input.
+
+ Mellon node definition:
+ "image_encoder": {"label": "Image Encoder", "type": "diffusers_auto_model", "display": "input"}
+
+ Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
+ components.get_one(model_id) to retrieve the actual model.
+ """
+ return cls(
+ name="image_encoder",
+ label="Image Encoder",
+ type="diffusers_auto_model",
+ display="input",
+ required_block_params=["image_encoder"],
+ )
@classmethod
def unet(cls) -> "MellonParam":
"""
- Denoising model (UNet/Transformer) info dict.
+ Denoising model (UNet/Transformer) input.
+
+ Mellon node definition:
+ "unet": {"label": "Denoise Model", "type": "diffusers_auto_model", "display": "input"}
- Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
- the actual model.
+ Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
+ components.get_one(model_id) to retrieve the actual model.
"""
return cls(name="unet", label="Denoise Model", type="diffusers_auto_model", display="input")
@classmethod
def scheduler(cls) -> "MellonParam":
"""
- Scheduler model info dict.
+ Scheduler model input.
+
+ Mellon node definition:
+ "scheduler": {"label": "Scheduler", "type": "diffusers_auto_model", "display": "input"}
- Contains keys like 'model_id', 'repo_id' etc. Use components.get_one(model_id) to retrieve the actual
- scheduler.
+ Note: The value received is a model info dict with keys like 'model_id', 'repo_id'. Use
+ components.get_one(model_id) to retrieve the actual scheduler.
"""
return cls(name="scheduler", label="Scheduler", type="diffusers_auto_model", display="input")
@classmethod
def controlnet(cls) -> "MellonParam":
"""
- ControlNet model info dict.
+ ControlNet model input.
- Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
- the actual model.
+ Mellon node definition:
+ "controlnet": {"label": "ControlNet Model", "type": "diffusers_auto_model", "display": "input"}
+
+ Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
+ components.get_one(model_id) to retrieve the actual model.
"""
- return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input")
+ return cls(
+ name="controlnet",
+ label="ControlNet Model",
+ type="diffusers_auto_model",
+ display="input",
+ required_block_params=["controlnet"],
+ )
@classmethod
def text_encoders(cls) -> "MellonParam":
"""
- Dict of text encoder model info dicts.
+ Text encoders dict input (multiple encoders).
+
+ Mellon node definition:
+ "text_encoders": {"label": "Text Encoders", "type": "diffusers_auto_models", "display": "input"}
- Structure: {
- 'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
- 'repo_id': '...'
- } Use components.get_one(model_id) to retrieve each model.
+ Note: The value received is a dict of model info dicts:
+ {
+ 'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
+ 'repo_id': '...'
+ }
+ Use components.get_one(model_id) to retrieve each model.
"""
- return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input")
+ return cls(
+ name="text_encoders",
+ label="Text Encoders",
+ type="diffusers_auto_models",
+ display="input",
+ required_block_params=["text_encoder"],
+ )
@classmethod
def controlnet_bundle(cls, display: str = "input") -> "MellonParam":
"""
- ControlNet bundle containing model info and processed control inputs.
+ ControlNet bundle containing model and processed control inputs. Output from ControlNet node, input to Denoise
+ node.
- Structure: {
- 'controlnet': {'model_id': ..., ...}, # controlnet model info dict 'control_image': ..., # processed
- control image/embeddings 'controlnet_conditioning_scale': ..., ... # other inputs expected by denoise
- blocks
- }
+ Mellon node definition (display="input"):
+ "controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "input"}
- Output from Controlnet node, input to Denoise node.
+ Mellon node definition (display="output"):
+ "controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "output"}
+
+ Note: The value is a dict containing:
+     {
+         'controlnet': {'model_id': ..., ...},  # controlnet model info
+         'control_image': ...,  # processed control image/embeddings
+         'controlnet_conditioning_scale': ...,  # and other denoise block inputs
+     }
"""
- return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display)
+ return cls(
+ name="controlnet_bundle",
+ label="ControlNet",
+ type="custom_controlnet",
+ display=display,
+ required_block_params="controlnet_image",
+ )
@classmethod
def ip_adapter(cls) -> "MellonParam":
+ """
+ IP-Adapter input.
+
+ Mellon node definition:
+ "ip_adapter": {"label": "IP Adapter", "type": "custom_ip_adapter", "display": "input"}
+ """
return cls(name="ip_adapter", label="IP Adapter", type="custom_ip_adapter", display="input")
@classmethod
def guider(cls) -> "MellonParam":
+ """
+ Custom guider input. When connected, hides the guidance_scale slider.
+
+ Mellon node definition:
+ "guider": {
+ "label": "Guider", "type": "custom_guider", "display": "input", "onChange": {false: ["guidance_scale"],
+ true: []}
+ }
+ """
return cls(
name="guider",
label="Guider",
@@ -251,9 +617,96 @@ def guider(cls) -> "MellonParam":
@classmethod
def doc(cls) -> "MellonParam":
+ """
+ Documentation output for inspecting the underlying modular pipeline.
+
+ Mellon node definition:
+ "doc": {"label": "Doc", "type": "string", "display": "output"}
+ """
return cls(name="doc", label="Doc", type="string", display="output")
+DEFAULT_NODE_SPECS = {
+ "controlnet": None,
+ "denoise": {
+ "inputs": [
+ MellonParam.embeddings(display="input"),
+ MellonParam.width(),
+ MellonParam.height(),
+ MellonParam.seed(),
+ MellonParam.num_inference_steps(),
+ MellonParam.num_frames(),
+ MellonParam.guidance_scale(),
+ MellonParam.strength(),
+ MellonParam.image_latents_with_strength(),
+ MellonParam.image_latents(),
+ MellonParam.first_frame_latents(),
+ MellonParam.controlnet_bundle(display="input"),
+ ],
+ "model_inputs": [
+ MellonParam.unet(),
+ MellonParam.guider(),
+ MellonParam.scheduler(),
+ ],
+ "outputs": [
+ MellonParam.latents(display="output"),
+ MellonParam.latents_preview(),
+ MellonParam.doc(),
+ ],
+ "required_inputs": ["embeddings"],
+ "required_model_inputs": ["unet", "scheduler"],
+ "block_name": "denoise",
+ },
+ "vae_encoder": {
+ "inputs": [
+ MellonParam.image(),
+ ],
+ "model_inputs": [
+ MellonParam.vae(),
+ ],
+ "outputs": [
+ MellonParam.image_latents(display="output"),
+ MellonParam.doc(),
+ ],
+ "required_inputs": ["image"],
+ "required_model_inputs": ["vae"],
+ "block_name": "vae_encoder",
+ },
+ "text_encoder": {
+ "inputs": [
+ MellonParam.prompt(),
+ MellonParam.negative_prompt(),
+ ],
+ "model_inputs": [
+ MellonParam.text_encoders(),
+ ],
+ "outputs": [
+ MellonParam.embeddings(display="output"),
+ MellonParam.doc(),
+ ],
+ "required_inputs": ["prompt"],
+ "required_model_inputs": ["text_encoders"],
+ "block_name": "text_encoder",
+ },
+ "decoder": {
+ "inputs": [
+ MellonParam.latents(display="input"),
+ ],
+ "model_inputs": [
+ MellonParam.vae(),
+ ],
+ "outputs": [
+ MellonParam.images(),
+ MellonParam.videos(),
+ MellonParam.doc(),
+ ],
+ "required_inputs": ["latents"],
+ "required_model_inputs": ["vae"],
+ "block_name": "decode",
+ },
+}
+
+
def mark_required(label: str, marker: str = " *") -> str:
"""Add required marker to label if not already present."""
if label.endswith(marker):
@@ -428,20 +881,42 @@ def __init__(
default_dtype: Default dtype (e.g., "float16", "bfloat16")
"""
# Convert all node specs to Mellon format immediately
- self.node_params = {}
- for node_type, spec in node_specs.items():
- if spec is None:
- self.node_params[node_type] = None
- else:
- self.node_params[node_type] = node_spec_to_mellon_dict(spec, node_type)
+ self.node_specs = node_specs
self.label = label
self.default_repo = default_repo
self.default_dtype = default_dtype
+ @property
+ def node_params(self) -> Dict[str, Any]:
+ """Lazily compute node_params from node_specs."""
+ if self.node_specs is None:
+ return self._node_params
+
+ params = {}
+ for node_type, spec in self.node_specs.items():
+ if spec is None:
+ params[node_type] = None
+ else:
+ params[node_type] = node_spec_to_mellon_dict(spec, node_type)
+ return params
+
def __repr__(self) -> str:
- node_types = list(self.node_params.keys())
- return f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r}, node_params={node_types})"
+ lines = [
+ f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"
+ ]
+ for node_type, spec in self.node_specs.items():
+ if spec is None:
+ lines.append(f" {node_type}: None")
+ else:
+ inputs = [p.name for p in spec.get("inputs", [])]
+ model_inputs = [p.name for p in spec.get("model_inputs", [])]
+ outputs = [p.name for p in spec.get("outputs", [])]
+ lines.append(f" {node_type}:")
+ lines.append(f" inputs: {inputs}")
+ lines.append(f" model_inputs: {model_inputs}")
+ lines.append(f" outputs: {outputs}")
+ return "\n".join(lines)
def to_dict(self) -> Dict[str, Any]:
"""Convert to a JSON-serializable dictionary."""
@@ -460,7 +935,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "MellonPipelineConfig":
Note: The mellon_params are already in Mellon format when loading from JSON.
"""
instance = cls.__new__(cls)
- instance.node_params = data.get("node_params", {})
+ instance.node_specs = None
+ instance._node_params = data.get("node_params", {})
instance.label = data.get("label", "")
instance.default_repo = data.get("default_repo", "")
instance.default_dtype = data.get("default_dtype", "")
@@ -592,3 +1068,85 @@ def load(
return cls.from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.")
+
+ @classmethod
+ def from_blocks(
+ cls,
+ blocks,
+ template: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
+ label: str = "",
+ default_repo: str = "",
+ default_dtype: str = "bfloat16",
+ ) -> "MellonPipelineConfig":
+ """
+ Create MellonPipelineConfig by matching template against actual pipeline blocks.
+ """
+ if template is None:
+ template = DEFAULT_NODE_SPECS
+
+ sub_block_map = dict(blocks.sub_blocks)
+
+ def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]:
+ """Filter template spec params based on what the block actually supports."""
+ block_input_names = set(block.input_names)
+ block_output_names = set(block.intermediate_output_names)
+ block_component_names = set(block.component_names)
+
+ filtered_inputs = [
+ p
+ for p in template_spec.get("inputs", [])
+ if p.required_block_params is None
+ or all(name in block_input_names for name in p.required_block_params)
+ ]
+ filtered_model_inputs = [
+ p
+ for p in template_spec.get("model_inputs", [])
+ if p.required_block_params is None
+ or all(name in block_component_names for name in p.required_block_params)
+ ]
+ filtered_outputs = [
+ p
+ for p in template_spec.get("outputs", [])
+ if p.required_block_params is None
+ or all(name in block_output_names for name in p.required_block_params)
+ ]
+
+ filtered_input_names = {p.name for p in filtered_inputs}
+ filtered_model_input_names = {p.name for p in filtered_model_inputs}
+
+ filtered_required_inputs = [
+ r for r in template_spec.get("required_inputs", []) if r in filtered_input_names
+ ]
+ filtered_required_model_inputs = [
+ r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names
+ ]
+
+ return {
+ "inputs": filtered_inputs,
+ "model_inputs": filtered_model_inputs,
+ "outputs": filtered_outputs,
+ "required_inputs": filtered_required_inputs,
+ "required_model_inputs": filtered_required_model_inputs,
+ "block_name": template_spec.get("block_name"),
+ }
+
+ # Build node specs
+ node_specs = {}
+ for node_type, template_spec in template.items():
+ if template_spec is None:
+ node_specs[node_type] = None
+ continue
+
+ block_name = template_spec.get("block_name")
+ if block_name is None or block_name not in sub_block_map:
+ node_specs[node_type] = None
+ continue
+
+ node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name])
+
+ return cls(
+ node_specs=node_specs,
+ label=label or getattr(blocks, "model_name", ""),
+ default_repo=default_repo,
+ default_dtype=default_dtype,
+ )
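
A minimal usage sketch of `from_blocks`, assuming `QwenImageAutoBlocks` exposes sub-block names that line up with the `block_name` entries in `DEFAULT_NODE_SPECS`; the import path for `MellonPipelineConfig` and the repo id are illustrative:

# Sketch only: adjust the MellonPipelineConfig import to the module this hunk edits.
from diffusers.modular_pipelines.qwenimage import QwenImageAutoBlocks
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig  # assumed path

blocks = QwenImageAutoBlocks()

# Filter the default node-spec template against the params the blocks actually
# expose; params whose required_block_params are not met get dropped.
config = MellonPipelineConfig.from_blocks(
    blocks,
    label="QwenImage",
    default_repo="Qwen/Qwen-Image",  # illustrative repo id
)

print(config)                # per-node inputs / model_inputs / outputs
payload = config.to_dict()   # JSON-serializable dict for the Mellon UI
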
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index c5fa4cf9921f..d857fd040955 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -62,6 +62,7 @@
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
+ ("qwenimage-layered", "QwenImageLayeredModularPipeline"),
("z-image", "ZImageModularPipeline"),
]
)
@@ -231,7 +232,7 @@ def format_value(v):
class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""
- Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
+ Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks,
LoopSequentialPipelineBlocks
[`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
@@ -527,9 +528,10 @@ def doc(self):
)
-class AutoPipelineBlocks(ModularPipelineBlocks):
+class ConditionalPipelineBlocks(ModularPipelineBlocks):
"""
- A Pipeline Blocks that automatically selects a block to run based on the inputs.
+ A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
+ `select_block` method to define the logic for selecting the block.
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -539,12 +541,13 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
Attributes:
block_classes: List of block classes to be used
block_names: List of prefixes for each block
- block_trigger_inputs: List of input names that trigger specific blocks, with None for default
+ block_trigger_inputs: List of input names that select_block() uses to determine which block to run
"""
block_classes = []
block_names = []
block_trigger_inputs = []
+ default_block_name = None # name of the default block to run when select_block() returns None; if None, the whole conditional block is skipped
def __init__(self):
sub_blocks = InsertableDict()
@@ -554,26 +557,15 @@ def __init__(self):
else:
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
- if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
+ if not (len(self.block_classes) == len(self.block_names)):
raise ValueError(
- f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
+ f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same."
)
- default_blocks = [t for t in self.block_trigger_inputs if t is None]
- # can only have 1 or 0 default block, and has to put in the last
- # the order of blocks matters here because the first block with matching trigger will be dispatched
- # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
- # as long as mask is provided, it is inpaint; if only image is provided, it is img2img
- if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
+ if self.default_block_name is not None and self.default_block_name not in self.block_names:
raise ValueError(
- f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
- "in block_trigger_inputs."
+ f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}"
)
- # Map trigger inputs to block objects
- self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
- self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
- self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
-
@property
def model_name(self):
return next(iter(self.sub_blocks.values())).model_name
@@ -602,8 +594,10 @@ def expected_configs(self):
@property
def required_inputs(self) -> List[str]:
- if None not in self.block_trigger_inputs:
+ # no default block means this conditional block can be skipped entirely
+ if self.default_block_name is None:
return []
+
first_block = next(iter(self.sub_blocks.values()))
required_by_all = set(getattr(first_block, "required_inputs", set()))
@@ -614,7 +608,6 @@ def required_inputs(self) -> List[str]:
return list(required_by_all)
- # YiYi TODO: add test for this
@property
def inputs(self) -> List[Tuple[str, Any]]:
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
@@ -639,36 +632,9 @@ def outputs(self) -> List[str]:
combined_outputs = self.combine_outputs(*named_outputs)
return combined_outputs
- @torch.no_grad()
- def __call__(self, pipeline, state: PipelineState) -> PipelineState:
- # Find default block first (if any)
-
- block = self.trigger_to_block_map.get(None)
- for input_name in self.block_trigger_inputs:
- if input_name is not None and state.get(input_name) is not None:
- block = self.trigger_to_block_map[input_name]
- break
-
- if block is None:
- logger.info(f"skipping auto block: {self.__class__.__name__}")
- return pipeline, state
-
- try:
- logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
- return block(pipeline, state)
- except Exception as e:
- error_msg = (
- f"\nError in block: {block.__class__.__name__}\n"
- f"Error details: {str(e)}\n"
- f"Traceback:\n{traceback.format_exc()}"
- )
- logger.error(error_msg)
- raise
-
- def _get_trigger_inputs(self):
+ def _get_trigger_inputs(self) -> set:
"""
- Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
- block_trigger_inputs values
+ Returns a set of all unique trigger input values found in this block and nested blocks.
"""
def fn_recursive_get_trigger(blocks):
@@ -676,9 +642,8 @@ def fn_recursive_get_trigger(blocks):
if blocks is not None:
for name, block in blocks.items():
- # Check if current block has trigger inputs(i.e. auto block)
+ # Check if current block has block_trigger_inputs
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
- # Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has sub_blocks, recursively check them
@@ -688,15 +653,57 @@ def fn_recursive_get_trigger(blocks):
return trigger_values
- trigger_inputs = set(self.block_trigger_inputs)
- trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
+ # Start with this block's block_trigger_inputs
+ all_triggers = {t for t in self.block_trigger_inputs if t is not None}
+ # Add nested triggers
+ all_triggers.update(fn_recursive_get_trigger(self.sub_blocks))
- return trigger_inputs
+ return all_triggers
@property
def trigger_inputs(self):
+ """All trigger inputs including from nested blocks."""
return self._get_trigger_inputs()
+ def select_block(self, **kwargs) -> Optional[str]:
+ """
+ Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
+ for selecting the block.
+
+ Args:
+ **kwargs: Trigger input names and their values from the state.
+
+ Returns:
+ Optional[str]: The name of the block to run, or None to use default/skip.
+ """
+ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.")
+
+ @torch.no_grad()
+ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+ trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None}
+ block_name = self.select_block(**trigger_kwargs)
+
+ if block_name is None:
+ block_name = self.default_block_name
+
+ if block_name is None:
+ logger.info(f"skipping conditional block: {self.__class__.__name__}")
+ return pipeline, state
+
+ block = self.sub_blocks[block_name]
+
+ try:
+ logger.info(f"Running block: {block.__class__.__name__}")
+ return block(pipeline, state)
+ except Exception as e:
+ error_msg = (
+ f"\nError in block: {block.__class__.__name__}\n"
+ f"Error details: {str(e)}\n"
+ f"Traceback:\n{traceback.format_exc()}"
+ )
+ logger.error(error_msg)
+ raise
+
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
@@ -708,7 +715,7 @@ def __repr__(self):
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
- header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
+ header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -729,31 +736,20 @@ def __repr__(self):
expected_configs = getattr(self, "expected_configs", [])
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
- # Blocks section - moved to the end with simplified format
+ # Blocks section
blocks_str = " Sub-Blocks:\n"
for i, (name, block) in enumerate(self.sub_blocks.items()):
- # Get trigger input for this block
- trigger = None
- if hasattr(self, "block_to_trigger_map"):
- trigger = self.block_to_trigger_map.get(name)
- # Format the trigger info
- if trigger is None:
- trigger_str = "[default]"
- elif isinstance(trigger, (list, tuple)):
- trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
- else:
- trigger_str = f"[trigger: {trigger}]"
- # For AutoPipelineBlocks, add bullet points
- blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
+ if name == self.default_block_name:
+ addtional_str = " [default]"
else:
- # For SequentialPipelineBlocks, show execution order
- blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+ addtional_str = ""
+ blocks_str += f" • {name}{addtional_str} ({block.__class__.__name__})\n"
# Add block description
- desc_lines = block.description.split("\n")
- indented_desc = desc_lines[0]
- if len(desc_lines) > 1:
- indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
+ block_desc_lines = block.description.split("\n")
+ indented_desc = block_desc_lines[0]
+ if len(block_desc_lines) > 1:
+ indented_desc += "\n" + "\n".join(" " + line for line in block_desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n"
# Build the representation with conditional sections
@@ -784,6 +780,35 @@ def doc(self):
)
+class AutoPipelineBlocks(ConditionalPipelineBlocks):
+ """
+ A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
+ raise ValueError(
+ f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
+ )
+
+ @property
+ def default_block_name(self) -> Optional[str]:
+ """Derive default_block_name from block_trigger_inputs (None entry)."""
+ if None in self.block_trigger_inputs:
+ idx = self.block_trigger_inputs.index(None)
+ return self.block_names[idx]
+ return None
+
+ def select_block(self, **kwargs) -> Optional[str]:
+ """Select block based on which trigger input is present (not None)."""
+ for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names):
+ if trigger_input is not None and kwargs.get(trigger_input) is not None:
+ return block_name
+ return None
+
+
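
For reference, the selection rule `AutoPipelineBlocks.select_block` implements can be sketched in plain Python: the first trigger input that is present wins, and the `None` entry in `block_trigger_inputs` marks the default block. Names below are illustrative; a real subclass declares these as class attributes.

# Sketch of the presence-based dispatch used by AutoPipelineBlocks.
block_trigger_inputs = ["mask", "image_latents", None]
block_names = ["inpaint", "img2img", "text2img"]
default_block_name = block_names[block_trigger_inputs.index(None)]  # "text2img"


def select_block(**kwargs):
    # first trigger input that is present (not None) selects its block
    for trigger_input, block_name in zip(block_trigger_inputs, block_names):
        if trigger_input is not None and kwargs.get(trigger_input) is not None:
            return block_name
    return None  # caller then falls back to default_block_name


assert select_block(mask="m", image_latents="l") == "inpaint"  # order matters
assert select_block(image_latents="l") == "img2img"
assert select_block() is None  # -> "text2img" via default_block_name
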
class SequentialPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
@@ -885,7 +910,8 @@ def _get_inputs(self):
# Only add outputs if the block cannot be skipped
should_add_outputs = True
- if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+ if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None:
+ # ConditionalPipelineBlocks without default can be skipped
should_add_outputs = False
if should_add_outputs:
@@ -948,8 +974,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
def _get_trigger_inputs(self):
"""
- Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
- block_trigger_inputs values
+ Returns a set of all unique trigger input values found in the blocks.
"""
def fn_recursive_get_trigger(blocks):
@@ -957,9 +982,8 @@ def fn_recursive_get_trigger(blocks):
if blocks is not None:
for name, block in blocks.items():
- # Check if current block has trigger inputs(i.e. auto block)
+ # Check if current block has block_trigger_inputs (ConditionalPipelineBlocks)
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
- # Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has sub_blocks, recursively check them
@@ -975,82 +999,84 @@ def fn_recursive_get_trigger(blocks):
def trigger_inputs(self):
return self._get_trigger_inputs()
- def _traverse_trigger_blocks(self, trigger_inputs):
- # Convert trigger_inputs to a set for easier manipulation
- active_triggers = set(trigger_inputs)
+ def _traverse_trigger_blocks(self, active_inputs):
+ """
+ Traverse blocks and select which ones would run given the active inputs.
+
+ Args:
+ active_inputs: Dict of input names to values that are "present"
+
+ Returns:
+ OrderedDict of block_name -> block that would execute
+ """
- def fn_recursive_traverse(block, block_name, active_triggers):
+ def fn_recursive_traverse(block, block_name, active_inputs):
result_blocks = OrderedDict()
- # sequential(include loopsequential) or PipelineBlock
- if not hasattr(block, "block_trigger_inputs"):
- if block.sub_blocks:
- # sequential or LoopSequentialPipelineBlocks (keep traversing)
- for sub_block_name, sub_block in block.sub_blocks.items():
- blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
- blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
- result_blocks.update(blocks_to_update)
+ # ConditionalPipelineBlocks (includes AutoPipelineBlocks)
+ if isinstance(block, ConditionalPipelineBlocks):
+ trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
+ selected_block_name = block.select_block(**trigger_kwargs)
+
+ if selected_block_name is None:
+ selected_block_name = block.default_block_name
+
+ if selected_block_name is None:
+ return result_blocks
+
+ selected_block = block.sub_blocks[selected_block_name]
+
+ if selected_block.sub_blocks:
+ result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
else:
- # PipelineBlock
- result_blocks[block_name] = block
- # Add this block's output names to active triggers if defined
- if hasattr(block, "outputs"):
- active_triggers.update(out.name for out in block.outputs)
+ result_blocks[block_name] = selected_block
+ if hasattr(selected_block, "outputs"):
+ for out in selected_block.outputs:
+ active_inputs[out.name] = True
+
return result_blocks
- # auto
+ # SequentialPipelineBlocks or LoopSequentialPipelineBlocks
+ if block.sub_blocks:
+ for sub_block_name, sub_block in block.sub_blocks.items():
+ blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
+ blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
+ result_blocks.update(blocks_to_update)
else:
- # Find first block_trigger_input that matches any value in our active_triggers
- this_block = None
- for trigger_input in block.block_trigger_inputs:
- if trigger_input is not None and trigger_input in active_triggers:
- this_block = block.trigger_to_block_map[trigger_input]
- break
-
- # If no matches found, try to get the default (None) block
- if this_block is None and None in block.block_trigger_inputs:
- this_block = block.trigger_to_block_map[None]
-
- if this_block is not None:
- # sequential/auto (keep traversing)
- if this_block.sub_blocks:
- result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
- else:
- # PipelineBlock
- result_blocks[block_name] = this_block
- # Add this block's output names to active triggers if defined
- # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
- if hasattr(this_block, "outputs"):
- active_triggers.update(out.name for out in this_block.outputs)
+ result_blocks[block_name] = block
+ if hasattr(block, "outputs"):
+ for out in block.outputs:
+ active_inputs[out.name] = True
return result_blocks
all_blocks = OrderedDict()
for block_name, block in self.sub_blocks.items():
- blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
+ blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(blocks_to_update)
return all_blocks
- def get_execution_blocks(self, *trigger_inputs):
- trigger_inputs_all = self.trigger_inputs
+ def get_execution_blocks(self, **kwargs):
+ """
+ Get the blocks that would execute given the specified inputs.
- if trigger_inputs is not None:
- if not isinstance(trigger_inputs, (list, tuple, set)):
- trigger_inputs = [trigger_inputs]
- invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
- if invalid_inputs:
- logger.warning(
- f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
- )
- trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
+ Args:
+ **kwargs: Input names and values. Only trigger inputs affect block selection.
+ Pass any inputs that would be non-None at runtime.
- if trigger_inputs is None:
- if None in trigger_inputs_all:
- trigger_inputs = [None]
- else:
- trigger_inputs = [trigger_inputs_all[0]]
- blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
+ Returns:
+ SequentialPipelineBlocks containing only the blocks that would execute
+
+ Example:
+ # Get blocks for inpainting workflow
+ blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)
+
+ # Get blocks for text2image workflow
+ blocks = pipeline.get_execution_blocks(prompt="a cat")
+ """
+ # Filter out None values
+ active_inputs = {k: v for k, v in kwargs.items() if v is not None}
+
+ blocks_triggered = self._traverse_trigger_blocks(active_inputs)
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
def __repr__(self):
@@ -1067,7 +1093,7 @@ def __repr__(self):
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
# Get first trigger input as example
example_input = next(t for t in self.trigger_inputs if t is not None)
- header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
+ header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -1091,22 +1117,8 @@ def __repr__(self):
# Blocks section - moved to the end with simplified format
blocks_str = " Sub-Blocks:\n"
for i, (name, block) in enumerate(self.sub_blocks.items()):
- # Get trigger input for this block
- trigger = None
- if hasattr(self, "block_to_trigger_map"):
- trigger = self.block_to_trigger_map.get(name)
- # Format the trigger info
- if trigger is None:
- trigger_str = "[default]"
- elif isinstance(trigger, (list, tuple)):
- trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
- else:
- trigger_str = f"[trigger: {trigger}]"
- # For AutoPipelineBlocks, add bullet points
- blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
- else:
- # For SequentialPipelineBlocks, show execution order
- blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+ # show execution order
+ blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# Add block description
desc_lines = block.description.split("\n")
@@ -1230,15 +1242,9 @@ def _get_inputs(self):
if inp.name not in outputs and inp not in inputs:
inputs.append(inp)
- # Only add outputs if the block cannot be skipped
- should_add_outputs = True
- if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
- should_add_outputs = False
-
- if should_add_outputs:
- # Add this block's outputs
- block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
- outputs.update(block_intermediate_outputs)
+ # Add this block's outputs
+ block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
+ outputs.update(block_intermediate_outputs)
for input_param in inputs:
if input_param.name in self.required_inputs:
@@ -1295,6 +1301,14 @@ def __init__(self):
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
+ # Validate that sub_blocks are only leaf blocks
+ for block_name, block in self.sub_blocks.items():
+ if block.sub_blocks:
+ raise ValueError(
+ f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). "
+ f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks."
+ )
+
@classmethod
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
"""
diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py
deleted file mode 100644
index f7ee1dd3097b..000000000000
--- a/src/diffusers/modular_pipelines/node_utils.py
+++ /dev/null
@@ -1,661 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import PIL
-import torch
-
-from ..configuration_utils import ConfigMixin
-from ..image_processor import PipelineImageInput
-from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
-from .modular_pipeline_utils import InputParam
-
-
-logger = logging.getLogger(__name__)
-
-# YiYi Notes: this is actually for SDXL, put it here for now
-SDXL_INPUTS_SCHEMA = {
- "prompt": InputParam(
- "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
- ),
- "prompt_2": InputParam(
- "prompt_2",
- type_hint=Union[str, List[str]],
- description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
- ),
- "negative_prompt": InputParam(
- "negative_prompt",
- type_hint=Union[str, List[str]],
- description="The prompt or prompts not to guide the image generation",
- ),
- "negative_prompt_2": InputParam(
- "negative_prompt_2",
- type_hint=Union[str, List[str]],
- description="The negative prompt or prompts for text_encoder_2",
- ),
- "cross_attention_kwargs": InputParam(
- "cross_attention_kwargs",
- type_hint=Optional[dict],
- description="Kwargs dictionary passed to the AttentionProcessor",
- ),
- "clip_skip": InputParam(
- "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
- ),
- "image": InputParam(
- "image",
- type_hint=PipelineImageInput,
- required=True,
- description="The image(s) to modify for img2img or inpainting",
- ),
- "mask_image": InputParam(
- "mask_image",
- type_hint=PipelineImageInput,
- required=True,
- description="Mask image for inpainting, white pixels will be repainted",
- ),
- "generator": InputParam(
- "generator",
- type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
- description="Generator(s) for deterministic generation",
- ),
- "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
- "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
- "num_images_per_prompt": InputParam(
- "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
- ),
- "num_inference_steps": InputParam(
- "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
- ),
- "timesteps": InputParam(
- "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
- ),
- "sigmas": InputParam(
- "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
- ),
- "denoising_end": InputParam(
- "denoising_end",
- type_hint=Optional[float],
- description="Fraction of denoising process to complete before termination",
- ),
- # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
- "strength": InputParam(
- "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
- ),
- "denoising_start": InputParam(
- "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
- ),
- "latents": InputParam(
- "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
- ),
- "padding_mask_crop": InputParam(
- "padding_mask_crop",
- type_hint=Optional[Tuple[int, int]],
- description="Size of margin in crop for image and mask",
- ),
- "original_size": InputParam(
- "original_size",
- type_hint=Optional[Tuple[int, int]],
- description="Original size of the image for SDXL's micro-conditioning",
- ),
- "target_size": InputParam(
- "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
- ),
- "negative_original_size": InputParam(
- "negative_original_size",
- type_hint=Optional[Tuple[int, int]],
- description="Negative conditioning based on image resolution",
- ),
- "negative_target_size": InputParam(
- "negative_target_size",
- type_hint=Optional[Tuple[int, int]],
- description="Negative conditioning based on target resolution",
- ),
- "crops_coords_top_left": InputParam(
- "crops_coords_top_left",
- type_hint=Tuple[int, int],
- default=(0, 0),
- description="Top-left coordinates for SDXL's micro-conditioning",
- ),
- "negative_crops_coords_top_left": InputParam(
- "negative_crops_coords_top_left",
- type_hint=Tuple[int, int],
- default=(0, 0),
- description="Negative conditioning crop coordinates",
- ),
- "aesthetic_score": InputParam(
- "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
- ),
- "negative_aesthetic_score": InputParam(
- "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
- ),
- "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
- "output_type": InputParam(
- "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
- ),
- "ip_adapter_image": InputParam(
- "ip_adapter_image",
- type_hint=PipelineImageInput,
- required=True,
- description="Image(s) to be used as IP adapter",
- ),
- "control_image": InputParam(
- "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
- ),
- "control_guidance_start": InputParam(
- "control_guidance_start",
- type_hint=Union[float, List[float]],
- default=0.0,
- description="When ControlNet starts applying",
- ),
- "control_guidance_end": InputParam(
- "control_guidance_end",
- type_hint=Union[float, List[float]],
- default=1.0,
- description="When ControlNet stops applying",
- ),
- "controlnet_conditioning_scale": InputParam(
- "controlnet_conditioning_scale",
- type_hint=Union[float, List[float]],
- default=1.0,
- description="Scale factor for ControlNet outputs",
- ),
- "guess_mode": InputParam(
- "guess_mode",
- type_hint=bool,
- default=False,
- description="Enables ControlNet encoder to recognize input without prompts",
- ),
- "control_mode": InputParam(
- "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
- ),
-}
-
-SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
- "prompt_embeds": InputParam(
- "prompt_embeds",
- type_hint=torch.Tensor,
- required=True,
- description="Text embeddings used to guide image generation",
- ),
- "negative_prompt_embeds": InputParam(
- "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
- ),
- "pooled_prompt_embeds": InputParam(
- "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
- ),
- "negative_pooled_prompt_embeds": InputParam(
- "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
- ),
- "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
- "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
- "preprocess_kwargs": InputParam(
- "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
- ),
- "latents": InputParam(
- "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
- ),
- "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
- "num_inference_steps": InputParam(
- "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
- ),
- "latent_timestep": InputParam(
- "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
- ),
- "image_latents": InputParam(
- "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
- ),
- "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
- "masked_image_latents": InputParam(
- "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
- ),
- "add_time_ids": InputParam(
- "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
- ),
- "negative_add_time_ids": InputParam(
- "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
- ),
- "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
- "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
- "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
- "ip_adapter_embeds": InputParam(
- "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
- ),
- "negative_ip_adapter_embeds": InputParam(
- "negative_ip_adapter_embeds",
- type_hint=List[torch.Tensor],
- description="Negative image embeddings for IP-Adapter",
- ),
- "images": InputParam(
- "images",
- type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
- required=True,
- description="Generated images",
- ),
-}
-
-SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
-
-
-DEFAULT_PARAM_MAPS = {
- "prompt": {
- "label": "Prompt",
- "type": "string",
- "default": "a bear sitting in a chair drinking a milkshake",
- "display": "textarea",
- },
- "negative_prompt": {
- "label": "Negative Prompt",
- "type": "string",
- "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
- "display": "textarea",
- },
- "num_inference_steps": {
- "label": "Steps",
- "type": "int",
- "default": 25,
- "min": 1,
- "max": 1000,
- },
- "seed": {
- "label": "Seed",
- "type": "int",
- "default": 0,
- "min": 0,
- "display": "random",
- },
- "width": {
- "label": "Width",
- "type": "int",
- "display": "text",
- "default": 1024,
- "min": 8,
- "max": 8192,
- "step": 8,
- "group": "dimensions",
- },
- "height": {
- "label": "Height",
- "type": "int",
- "display": "text",
- "default": 1024,
- "min": 8,
- "max": 8192,
- "step": 8,
- "group": "dimensions",
- },
- "images": {
- "label": "Images",
- "type": "image",
- "display": "output",
- },
- "image": {
- "label": "Image",
- "type": "image",
- "display": "input",
- },
-}
-
-DEFAULT_TYPE_MAPS = {
- "int": {
- "type": "int",
- "default": 0,
- "min": 0,
- },
- "float": {
- "type": "float",
- "default": 0.0,
- "min": 0.0,
- },
- "str": {
- "type": "string",
- "default": "",
- },
- "bool": {
- "type": "boolean",
- "default": False,
- },
- "image": {
- "type": "image",
- },
-}
-
-DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
-DEFAULT_CATEGORY = "Modular Diffusers"
-DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
-DEFAULT_PARAMS_GROUPS_KEYS = {
- "text_encoders": ["text_encoder", "tokenizer"],
- "ip_adapter_embeds": ["ip_adapter_embeds"],
- "prompt_embeddings": ["prompt_embeds"],
-}
-
-
-def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
- """
- Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
- "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
- """
- if name is None:
- return None
- for group_name, group_keys in group_params_keys.items():
- for group_key in group_keys:
- if group_key in name:
- return group_name
- return None
-
-
-class ModularNode(ConfigMixin):
- """
- A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
- around a ModularPipelineBlocks object.
-
- > [!WARNING] > This is an experimental feature and is likely to change in the future.
- """
-
- config_name = "node_config.json"
-
- @classmethod
- def from_pretrained(
- cls,
- pretrained_model_name_or_path: str,
- trust_remote_code: Optional[bool] = None,
- **kwargs,
- ):
- blocks = ModularPipelineBlocks.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
- )
- return cls(blocks, **kwargs)
-
- def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
- self.blocks = blocks
-
- if label is None:
- label = self.blocks.__class__.__name__
- # blocks param name -> mellon param name
- self.name_mapping = {}
-
- input_params = {}
- # pass or create a default param dict for each input
- # e.g. for prompt,
- # prompt = {
- # "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
- # "label": "Prompt",
- # "type": "string",
- # "default": "a bear sitting in a chair drinking a milkshake",
- # "display": "textarea"}
- # if type is not specified, it'll be a "custom" param of its own type
- # e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
- # it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
- # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
- inputs = self.blocks.inputs + self.blocks.intermediate_inputs
- for inp in inputs:
- param = kwargs.pop(inp.name, None)
- if param:
- # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
- input_params[inp.name] = param
- mellon_name = param.pop("name", inp.name)
- if mellon_name != inp.name:
- self.name_mapping[inp.name] = mellon_name
- continue
-
- if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
- continue
-
- if inp.name in DEFAULT_PARAM_MAPS:
- # first check if it's in the default param map, if so, directly use that
- param = DEFAULT_PARAM_MAPS[inp.name].copy()
- elif get_group_name(inp.name):
- param = get_group_name(inp.name)
- if inp.name not in self.name_mapping:
- self.name_mapping[inp.name] = param
- else:
- # if not, check if it's in the SDXL input schema, if so,
- # 1. use the type hint to determine the type
- # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
- if inp.type_hint is not None:
- type_str = str(inp.type_hint).lower()
- else:
- inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
- type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
- for type_key, type_param in DEFAULT_TYPE_MAPS.items():
- if type_key in type_str:
- param = type_param.copy()
- param["label"] = inp.name
- param["display"] = "input"
- break
- else:
- param = inp.name
- # add the param dict to the inp_params dict
- input_params[inp.name] = param
-
- component_params = {}
- for comp in self.blocks.expected_components:
- param = kwargs.pop(comp.name, None)
- if param:
- component_params[comp.name] = param
- mellon_name = param.pop("name", comp.name)
- if mellon_name != comp.name:
- self.name_mapping[comp.name] = mellon_name
- continue
-
- to_exclude = False
- for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
- if exclude_key in comp.name:
- to_exclude = True
- break
- if to_exclude:
- continue
-
- if get_group_name(comp.name):
- param = get_group_name(comp.name)
- if comp.name not in self.name_mapping:
- self.name_mapping[comp.name] = param
- elif comp.name in DEFAULT_MODEL_KEYS:
- param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
- else:
- param = comp.name
- # add the param dict to the model_params dict
- component_params[comp.name] = param
-
- output_params = {}
- if isinstance(self.blocks, SequentialPipelineBlocks):
- last_block_name = list(self.blocks.sub_blocks.keys())[-1]
- outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
- else:
- outputs = self.blocks.intermediate_outputs
-
- for out in outputs:
- param = kwargs.pop(out.name, None)
- if param:
- output_params[out.name] = param
- mellon_name = param.pop("name", out.name)
- if mellon_name != out.name:
- self.name_mapping[out.name] = mellon_name
- continue
-
- if out.name in DEFAULT_PARAM_MAPS:
- param = DEFAULT_PARAM_MAPS[out.name].copy()
- param["display"] = "output"
- else:
- group_name = get_group_name(out.name)
- if group_name:
- param = group_name
- if out.name not in self.name_mapping:
- self.name_mapping[out.name] = param
- else:
- param = out.name
- # add the param dict to the outputs dict
- output_params[out.name] = param
-
- if len(kwargs) > 0:
- logger.warning(f"Unused kwargs: {kwargs}")
-
- register_dict = {
- "category": category,
- "label": label,
- "input_params": input_params,
- "component_params": component_params,
- "output_params": output_params,
- "name_mapping": self.name_mapping,
- }
- self.register_to_config(**register_dict)
-
- def setup(self, components_manager, collection=None):
- self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
- self._components_manager = components_manager
-
- @property
- def mellon_config(self):
- return self._convert_to_mellon_config()
-
- def _convert_to_mellon_config(self):
- node = {}
- node["label"] = self.config.label
- node["category"] = self.config.category
-
- node_param = {}
- for inp_name, inp_param in self.config.input_params.items():
- if inp_name in self.name_mapping:
- mellon_name = self.name_mapping[inp_name]
- else:
- mellon_name = inp_name
- if isinstance(inp_param, str):
- param = {
- "label": inp_param,
- "type": inp_param,
- "display": "input",
- }
- else:
- param = inp_param
-
- if mellon_name not in node_param:
- node_param[mellon_name] = param
- else:
- logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
-
- for comp_name, comp_param in self.config.component_params.items():
- if comp_name in self.name_mapping:
- mellon_name = self.name_mapping[comp_name]
- else:
- mellon_name = comp_name
- if isinstance(comp_param, str):
- param = {
- "label": comp_param,
- "type": comp_param,
- "display": "input",
- }
- else:
- param = comp_param
-
- if mellon_name not in node_param:
- node_param[mellon_name] = param
- else:
- logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
-
- for out_name, out_param in self.config.output_params.items():
- if out_name in self.name_mapping:
- mellon_name = self.name_mapping[out_name]
- else:
- mellon_name = out_name
- if isinstance(out_param, str):
- param = {
- "label": out_param,
- "type": out_param,
- "display": "output",
- }
- else:
- param = out_param
-
- if mellon_name not in node_param:
- node_param[mellon_name] = param
- else:
- logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
- node["params"] = node_param
- return node
-
- def save_mellon_config(self, file_path):
- """
- Save the Mellon configuration to a JSON file.
-
- Args:
- file_path (str or Path): Path where the JSON file will be saved
-
- Returns:
- Path: Path to the saved config file
- """
- file_path = Path(file_path)
-
- # Create directory if it doesn't exist
- os.makedirs(file_path.parent, exist_ok=True)
-
- # Create a combined dictionary with module definition and name mapping
- config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
-
- # Save the config to file
- with open(file_path, "w", encoding="utf-8") as f:
- json.dump(config, f, indent=2)
-
- logger.info(f"Mellon config and name mapping saved to {file_path}")
-
- return file_path
-
- @classmethod
- def load_mellon_config(cls, file_path):
- """
- Load a Mellon configuration from a JSON file.
-
- Args:
- file_path (str or Path): Path to the JSON file containing Mellon config
-
- Returns:
- dict: The loaded combined configuration containing 'module' and 'name_mapping'
- """
- file_path = Path(file_path)
-
- if not file_path.exists():
- raise FileNotFoundError(f"Config file not found: {file_path}")
-
- with open(file_path, "r", encoding="utf-8") as f:
- config = json.load(f)
-
- logger.info(f"Mellon config loaded from {file_path}")
-
- return config
-
- def process_inputs(self, **kwargs):
- params_components = {}
- for comp_name, comp_param in self.config.component_params.items():
- logger.debug(f"component: {comp_name}")
- mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
- if mellon_comp_name in kwargs:
- if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
- comp = kwargs[mellon_comp_name].pop(comp_name)
- else:
- comp = kwargs.pop(mellon_comp_name)
- if comp:
- params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
-
- params_run = {}
- for inp_name, inp_param in self.config.input_params.items():
- logger.debug(f"input: {inp_name}")
- mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
- if mellon_inp_name in kwargs:
- if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
- inp = kwargs[mellon_inp_name].pop(inp_name)
- else:
- inp = kwargs.pop(mellon_inp_name)
- if inp is not None:
- params_run[inp_name] = inp
-
- return_output_names = list(self.config.output_params.keys())
-
- return params_components, params_run, return_output_names
-
- def execute(self, **kwargs):
- params_components, params_run, return_output_names = self.process_inputs(**kwargs)
-
- self.pipeline.update_components(**params_components)
- output = self.pipeline(**params_run, output=return_output_names)
- return output
diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py
index ae4ec4799fbc..2b01a5b5a4b5 100644
--- a/src/diffusers/modular_pipelines/qwenimage/__init__.py
+++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py
@@ -21,27 +21,27 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
- _import_structure["encoders"] = ["QwenImageTextEncoderStep"]
- _import_structure["modular_blocks"] = [
- "ALL_BLOCKS",
+ _import_structure["modular_blocks_qwenimage"] = [
"AUTO_BLOCKS",
- "CONTROLNET_BLOCKS",
- "EDIT_AUTO_BLOCKS",
- "EDIT_BLOCKS",
- "EDIT_INPAINT_BLOCKS",
- "EDIT_PLUS_AUTO_BLOCKS",
- "EDIT_PLUS_BLOCKS",
- "IMAGE2IMAGE_BLOCKS",
- "INPAINT_BLOCKS",
- "TEXT2IMAGE_BLOCKS",
"QwenImageAutoBlocks",
+ ]
+ _import_structure["modular_blocks_qwenimage_edit"] = [
+ "EDIT_AUTO_BLOCKS",
"QwenImageEditAutoBlocks",
+ ]
+ _import_structure["modular_blocks_qwenimage_edit_plus"] = [
+ "EDIT_PLUS_AUTO_BLOCKS",
"QwenImageEditPlusAutoBlocks",
]
+ _import_structure["modular_blocks_qwenimage_layered"] = [
+ "LAYERED_AUTO_BLOCKS",
+ "QwenImageLayeredAutoBlocks",
+ ]
_import_structure["modular_pipeline"] = [
"QwenImageEditModularPipeline",
"QwenImageEditPlusModularPipeline",
"QwenImageModularPipeline",
+ "QwenImageLayeredModularPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -51,28 +51,26 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
- from .encoders import (
- QwenImageTextEncoderStep,
- )
- from .modular_blocks import (
- ALL_BLOCKS,
+ from .modular_blocks_qwenimage import (
AUTO_BLOCKS,
- CONTROLNET_BLOCKS,
- EDIT_AUTO_BLOCKS,
- EDIT_BLOCKS,
- EDIT_INPAINT_BLOCKS,
- EDIT_PLUS_AUTO_BLOCKS,
- EDIT_PLUS_BLOCKS,
- IMAGE2IMAGE_BLOCKS,
- INPAINT_BLOCKS,
- TEXT2IMAGE_BLOCKS,
QwenImageAutoBlocks,
+ )
+ from .modular_blocks_qwenimage_edit import (
+ EDIT_AUTO_BLOCKS,
QwenImageEditAutoBlocks,
+ )
+ from .modular_blocks_qwenimage_edit_plus import (
+ EDIT_PLUS_AUTO_BLOCKS,
QwenImageEditPlusAutoBlocks,
)
+ from .modular_blocks_qwenimage_layered import (
+ LAYERED_AUTO_BLOCKS,
+ QwenImageLayeredAutoBlocks,
+ )
from .modular_pipeline import (
QwenImageEditModularPipeline,
QwenImageEditPlusModularPipeline,
+ QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
)
else:
diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
index bd92d403539e..d9c8cbb01d18 100644
--- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -23,7 +23,7 @@
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
-from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
@@ -113,7 +113,9 @@ def get_timesteps(scheduler, num_inference_steps, strength):
return timesteps, num_inference_steps - t_start
-# Prepare Latents steps
+# ====================
+# 1. PREPARE LATENTS
+# ====================
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
@@ -207,6 +209,98 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
+class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "qwenimage-layered"
+
+ @property
+ def description(self) -> str:
+ return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("latents"),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="layers", default=4),
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="generator"),
+ InputParam(
+ name="batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ name="dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of the model inputs, can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="latents",
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ height=block_state.height,
+ width=block_state.width,
+ vae_scale_factor=components.vae_scale_factor,
+ )
+
+ device = components._execution_device
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
+ # we can update the height and width here since they are used to generate the initial latents
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+
+ shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width)
+ if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if block_state.latents is None:
+ block_state.latents = randn_tensor(
+ shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+ )
+ block_state.latents = components.pachifier.pack_latents(block_state.latents)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
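
A quick shape check for this layered prepare-latents step (a sketch; 16 latent channels and a VAE scale factor of 8 are assumed and may differ for the actual QwenImage VAE):

# Unpacked noise shape before pachifier.pack_latents, with the default layers=4.
batch_size, layers = 1, 4
num_channels_latents, vae_scale_factor = 16, 8  # assumed values
height, width = 1024, 1024

latent_height = 2 * (height // (vae_scale_factor * 2))  # 128
latent_width = 2 * (width // (vae_scale_factor * 2))    # 128

shape = (batch_size, layers + 1, num_channels_latents, latent_height, latent_width)
print(shape)  # (1, 5, 16, 128, 128) -> layers + 1 latent maps per sample
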
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -351,7 +445,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
-# Set Timesteps steps
+# ====================
+# 2. SET TIMESTEPS
+# ====================
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
@@ -359,7 +455,7 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
@property
def description(self) -> str:
- return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
+ return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -420,12 +516,70 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
+class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "qwenimage-layered"
+
+ @property
+ def description(self) -> str:
+ return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50, type_hint=int),
+ InputParam("sigmas", type_hint=List[float]),
+ InputParam("image_latents", required=True, type_hint=torch.Tensor),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="timesteps", type_hint=torch.Tensor),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ # Layered-specific mu calculation
+ base_seqlen = 256 * 256 / 16 / 16 # = 256
+ mu = (block_state.image_latents.shape[1] / base_seqlen) ** 0.5
+
+ # Default sigmas if not provided
+ sigmas = (
+ np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+ if block_state.sigmas is None
+ else block_state.sigmas
+ )
+
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ components.scheduler,
+ block_state.num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ components.scheduler.set_begin_index(0)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
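+# mu sketch (hypothetical numbers): base_seqlen = 256 * 256 / 16 / 16 = 256, so for packed
+# image_latents with sequence length 1600 (e.g. a 640x640 condition image packed into a
+# 40x40 token grid),
+#   mu = (1600 / 256) ** 0.5 = 2.5
+# which retrieve_timesteps passes through to the scheduler's set_timesteps.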
+
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
- return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
+ return "Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -493,7 +647,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
-# other inputs for denoiser
+# ====================
+# 3. OTHER INPUTS FOR DENOISER
+# ====================
## RoPE inputs for denoiser
@@ -522,21 +678,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="img_shapes",
+ kwargs_type="denoiser_input_fields",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
- OutputParam(
- name="txt_seq_lens",
- kwargs_type="denoiser_input_fields",
- type_hint=List[int],
- description="The sequence lengths of the prompt embeds, used for RoPE calculation",
- ),
- OutputParam(
- name="negative_txt_seq_lens",
- kwargs_type="denoiser_input_fields",
- type_hint=List[int],
- description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
- ),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -551,14 +696,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
)
]
] * block_state.batch_size
- block_state.txt_seq_lens = (
- block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
- )
- block_state.negative_txt_seq_lens = (
- block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
- if block_state.negative_prompt_embeds_mask is not None
- else None
- )
self.set_block_state(state, block_state)
@@ -589,21 +726,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="img_shapes",
+ kwargs_type="denoiser_input_fields",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
- OutputParam(
- name="txt_seq_lens",
- kwargs_type="denoiser_input_fields",
- type_hint=List[int],
- description="The sequence lengths of the prompt embeds, used for RoPE calculation",
- ),
- OutputParam(
- name="negative_txt_seq_lens",
- kwargs_type="denoiser_input_fields",
- type_hint=List[int],
- description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
- ),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -625,33 +751,69 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
]
] * block_state.batch_size
- block_state.txt_seq_lens = (
- block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
- )
- block_state.negative_txt_seq_lens = (
- block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
- if block_state.negative_prompt_embeds_mask is not None
- else None
- )
-
self.set_block_state(state, block_state)
return components, state
-class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
+class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
model_name = "qwenimage-edit-plus"
+ @property
+ def description(self) -> str:
+ return (
+            "Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit Plus.\n"
+            "Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n"
+            "Should be placed after the prepare_latents step."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(name="image_height", required=True, type_hint=List[int]),
+ InputParam(name="image_width", required=True, type_hint=List[int]),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[List[Tuple[int, int, int]]],
+ description="The shapes of the image latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ ]
+
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae_scale_factor = components.vae_scale_factor
+
+ # Edit Plus: image_height and image_width are lists
block_state.img_shapes = [
[
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
*[
- (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
- for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
+ (1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2)
+ for img_height, img_width in zip(block_state.image_height, block_state.image_width)
],
]
] * block_state.batch_size
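+
+        # e.g. (hypothetical numbers) with height = width = 1024, vae_scale_factor = 8 and two
+        # reference images resized to 1024x768 and 512x512, each batch entry becomes
+        #   [(1, 64, 64), (1, 48, 64), (1, 32, 32)]
+        # i.e. the target latent grid followed by one grid per reference image.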
@@ -670,6 +832,87 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
+class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage-layered"
+
+ @property
+ def description(self) -> str:
+ return (
+            "Step that prepares the RoPE inputs for the denoising process. Should be placed after the prepare_latents step."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(name="layers", required=True),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ type_hint=List[List[Tuple[int, int, int]]],
+ kwargs_type="denoiser_input_fields",
+ description="The shapes of the image latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ type_hint=List[int],
+ kwargs_type="denoiser_input_fields",
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ type_hint=List[int],
+ kwargs_type="denoiser_input_fields",
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="additional_t_cond",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="The additional t cond, used for RoPE calculation",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ # All shapes are the same for Layered
+ shape = (
+ 1,
+ block_state.height // components.vae_scale_factor // 2,
+ block_state.width // components.vae_scale_factor // 2,
+ )
+
+ # layers+1 output shapes + 1 condition shape (all same)
+ block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size
+
+ # txt_seq_lens
+ block_state.txt_seq_lens = (
+ block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+ )
+ block_state.negative_txt_seq_lens = (
+ block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if block_state.negative_prompt_embeds_mask is not None
+ else None
+ )
+
+ block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
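+# RoPE sketch for Layered (hypothetical numbers): with height = width = 640, vae_scale_factor = 8,
+# layers = 4 and batch_size = 1, every entry shares the same grid
+#   shape = (1, 640 // 8 // 2, 640 // 8 // 2) = (1, 40, 40)
+#   img_shapes = [[(1, 40, 40)] * 6]   # layers + 1 output frames + 1 condition frame
+# and additional_t_cond is a zero tensor of length batch_size.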
+
## ControlNet inputs for denoiser
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
model_name = "qwenimage"
diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py
index 6e145f18550a..24a88ebfca3c 100644
--- a/src/diffusers/modular_pipelines/qwenimage/decoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -24,12 +24,13 @@
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
-from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
logger = logging.get_logger(__name__)
+# after denoising loop (unpack latents)
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -71,6 +72,46 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
+class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
+ model_name = "qwenimage-layered"
+
+ @property
+ def description(self) -> str:
+ return "Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("latents", required=True, type_hint=torch.Tensor),
+ InputParam("height", required=True, type_hint=int),
+ InputParam("width", required=True, type_hint=int),
+ InputParam("layers", required=True, type_hint=int),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W)
+ block_state.latents = components.pachifier.unpack_latents(
+ block_state.latents,
+ block_state.height,
+ block_state.width,
+ block_state.layers,
+ components.vae_scale_factor,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
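+# Unpack sketch (hypothetical numbers, assuming 2x2 patch packing across all layers + 1 frames):
+# for batch_size = 1, layers = 4, 16 latent channels and an 80x80 latent grid, the packed
+# (1, 5 * 40 * 40, 64) tensor is unpacked back to (1, 16, 5, 80, 80) for the decoder below.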
+
+# decode step
class QwenImageDecoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -135,6 +176,81 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
+class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage-layered"
+
+ @property
+ def description(self) -> str:
+ return "Decode unpacked latents (B, C, layers+1, H, W) into layer images."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("latents", required=True, type_hint=torch.Tensor),
+ InputParam("output_type", default="pil", type_hint=str),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ latents = block_state.latents
+
+ # 1. VAE normalization
+ latents = latents.to(components.vae.dtype)
+ latents_mean = (
+ torch.tensor(components.vae.config.latents_mean)
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+ 1, components.vae.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latents = latents / latents_std + latents_mean
+
+ # 2. Reshape for batch decoding: (B, C, layers+1, H, W) -> (B*layers, C, 1, H, W)
+ b, c, f, h, w = latents.shape
+ # 3. Remove first frame (composite), keep layers frames
+ latents = latents[:, :, 1:]
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
+
+ # 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W)
+ image = components.vae.decode(latents, return_dict=False)[0]
+ image = image.squeeze(2)
+
+ # 5. Postprocess - returns flat list of B*layers images
+ image = components.image_processor.postprocess(image, output_type=block_state.output_type)
+
+        # 6. Chunk into one list of `layers` images per batch item
+        #    (the composite frame was dropped above, so each item contributes f - 1 images)
+        images = []
+        for bidx in range(b):
+            images.append(image[bidx * (f - 1) : (bidx + 1) * (f - 1)])
+
+ block_state.images = images
+
+ self.set_block_state(state, block_state)
+ return components, state
+
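+# Decode sketch (hypothetical numbers): for latents of shape (2, 16, 5, 80, 80) with layers = 4,
+# the composite frame is dropped, the remaining frames are flattened to (8, 16, 1, 80, 80),
+# decoded into 8 images, and then chunked back into 2 lists of 4 layer images.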
+
+# postprocess the decoded images
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
model_name = "qwenimage"
diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py
index 49acd2dc0295..d6bcb4a94f80 100644
--- a/src/diffusers/modular_pipelines/qwenimage/denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
from typing import List, Tuple
import torch
@@ -28,7 +29,12 @@
logger = logging.get_logger(__name__)
+# ====================
+# 1. LOOP STEPS (run at each denoising step)
+# ====================
+
+# loop step:before denoiser
class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -60,7 +66,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
- model_name = "qwenimage"
+ model_name = "qwenimage-edit"
@property
def description(self) -> str:
@@ -149,7 +155,7 @@ def inputs(self) -> List[InputParam]:
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs for the denoiser. "
- "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
+ "It should contain prompt_embeds/negative_prompt_embeds."
),
),
]
@@ -176,7 +182,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
img_shapes=block_state.img_shapes,
encoder_hidden_states=block_state.prompt_embeds,
encoder_hidden_states_mask=block_state.prompt_embeds_mask,
- txt_seq_lens=block_state.txt_seq_lens,
return_dict=False,
)
@@ -185,6 +190,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
return components, block_state
+# loop step:denoiser
class QwenImageLoopDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -247,12 +253,15 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
- "txt_seq_lens": (
- getattr(block_state, "txt_seq_lens", None),
- getattr(block_state, "negative_txt_seq_lens", None),
- ),
}
+ transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
+ additional_cond_kwargs = {}
+ for field_name, field_value in block_state.denoiser_input_fields.items():
+ if field_name in transformer_args and field_name not in guider_inputs:
+ additional_cond_kwargs[field_name] = field_value
+ block_state.additional_cond_kwargs.update(additional_cond_kwargs)
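+        # Any remaining `denoiser_input_fields` entries (e.g. `img_shapes`, which is no longer
+        # passed explicitly below) are collected into `additional_cond_kwargs` only when the
+        # transformer's forward signature accepts them, so variant-specific conditioning can
+        # reach the denoiser without changing this block.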
+
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
@@ -264,7 +273,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=block_state.timestep / 1000,
- img_shapes=block_state.img_shapes,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
@@ -284,7 +292,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
- model_name = "qwenimage"
+ model_name = "qwenimage-edit"
@property
def description(self) -> str:
@@ -345,12 +353,15 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
- "txt_seq_lens": (
- getattr(block_state, "txt_seq_lens", None),
- getattr(block_state, "negative_txt_seq_lens", None),
- ),
}
+ transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
+ additional_cond_kwargs = {}
+ for field_name, field_value in block_state.denoiser_input_fields.items():
+ if field_name in transformer_args and field_name not in guider_inputs:
+ additional_cond_kwargs[field_name] = field_value
+ block_state.additional_cond_kwargs.update(additional_cond_kwargs)
+
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
@@ -362,7 +373,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=block_state.timestep / 1000,
- img_shapes=block_state.img_shapes,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
@@ -384,6 +394,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
return components, block_state
+# loop step:after denoiser
class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -481,6 +492,9 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
return components, block_state
+# ====================
+# 2. DENOISE LOOP WRAPPER: define the denoising loop logic
+# ====================
class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "qwenimage"
@@ -537,8 +551,15 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
-# composing the denoising loops
+# ====================
+# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps
+# ====================
+
+
+# Qwen Image (text2image, image2image)
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
+ model_name = "qwenimage"
+
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopDenoiser,
@@ -559,8 +580,9 @@ def description(self) -> str:
)
-# composing the inpainting denoising loops
+# Qwen Image (inpainting)
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopDenoiser,
@@ -583,8 +605,9 @@ def description(self) -> str:
)
-# composing the controlnet denoising loops
+# Qwen Image (text2image, image2image) with controlnet
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopBeforeDenoiserControlNet,
@@ -607,8 +630,9 @@ def description(self) -> str:
)
-# composing the controlnet denoising loops
+# Qwen Image (inpainting) with controlnet
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopBeforeDenoiserControlNet,
@@ -639,8 +663,9 @@ def description(self) -> str:
)
-# composing the denoising loops
+# Qwen Image Edit (image2image)
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
+ model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
QwenImageEditLoopDenoiser,
@@ -661,7 +686,9 @@ def description(self) -> str:
)
+# Qwen Image Edit (inpainting)
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
QwenImageEditLoopDenoiser,
@@ -682,3 +709,26 @@ def description(self) -> str:
" - `QwenImageLoopAfterDenoiserInpaint`\n"
"This block supports inpainting tasks for QwenImage Edit."
)
+
+
+# Qwen Image Layered (image2image)
+class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
+ model_name = "qwenimage-layered"
+ block_classes = [
+ QwenImageEditLoopBeforeDenoiser,
+ QwenImageEditLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+            "Denoise step that iteratively denoises the latents. \n"
+            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
+            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+ " - `QwenImageEditLoopBeforeDenoiser`\n"
+ " - `QwenImageEditLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports QwenImage Layered."
+ )
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index b126a368bfdf..4b66dd32e521 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Text and VAE encoder blocks for QwenImage pipelines.
+"""
+
from typing import Dict, List, Optional, Union
import PIL
@@ -28,6 +32,17 @@
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline
+from .prompt_templates import (
+ QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE,
+ QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE,
+ QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX,
+ QWENIMAGE_EDIT_PROMPT_TEMPLATE,
+ QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX,
+ QWENIMAGE_LAYERED_CAPTION_PROMPT_CN,
+ QWENIMAGE_LAYERED_CAPTION_PROMPT_EN,
+ QWENIMAGE_PROMPT_TEMPLATE,
+ QWENIMAGE_PROMPT_TEMPLATE_START_IDX,
+)
logger = logging.get_logger(__name__)
@@ -45,8 +60,8 @@ def get_qwen_prompt_embeds(
text_encoder,
tokenizer,
prompt: Union[str, List[str]] = None,
- prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
- prompt_template_encode_start_idx: int = 34,
+ prompt_template_encode: str = QWENIMAGE_PROMPT_TEMPLATE,
+ prompt_template_encode_start_idx: int = QWENIMAGE_PROMPT_TEMPLATE_START_IDX,
tokenizer_max_length: int = 1024,
device: Optional[torch.device] = None,
):
@@ -86,8 +101,8 @@ def get_qwen_prompt_embeds_edit(
processor,
prompt: Union[str, List[str]] = None,
image: Optional[torch.Tensor] = None,
- prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
- prompt_template_encode_start_idx: int = 64,
+ prompt_template_encode: str = QWENIMAGE_EDIT_PROMPT_TEMPLATE,
+ prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX,
device: Optional[torch.device] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -133,9 +148,9 @@ def get_qwen_prompt_embeds_edit_plus(
processor,
prompt: Union[str, List[str]] = None,
image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
- prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
- img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
- prompt_template_encode_start_idx: int = 64,
+ prompt_template_encode: str = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE,
+ img_template_encode: str = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE,
+ prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX,
device: Optional[torch.device] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -241,15 +256,18 @@ def encode_vae_image(
return image_latents
-class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
- model_name = "qwenimage"
-
- def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
- """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+# ====================
+# 1. RESIZE
+# ====================
+class QwenImageEditResizeStep(ModularPipelineBlocks):
+ model_name = "qwenimage-edit"
- This block resizes an input image tensor and exposes the resized result under configurable input and output
- names. Use this when you need to wire the resize step to different image fields (e.g., "image",
- "control_image")
+ def __init__(
+ self,
+ input_name: str = "image",
+ output_name: str = "resized_image",
+ ):
+ """Create a configurable step for resizing images to the target area while maintaining the aspect ratio.
Args:
input_name (str, optional): Name of the image field to read from the
@@ -267,7 +285,7 @@ def __init__(self, input_name: str = "image", output_name: str = "resized_image"
@property
def description(self) -> str:
- return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
+        return f"Image Resize step that resizes the {self._image_input_name} to the target area while maintaining the aspect ratio."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -321,89 +339,289 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
return components, state
-class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
- model_name = "qwenimage"
+class QwenImageLayeredResizeStep(ModularPipelineBlocks):
+ model_name = "qwenimage-layered"
def __init__(
self,
input_name: str = "image",
output_name: str = "resized_image",
- vae_image_output_name: str = "vae_image",
):
- """Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
-
- This block resizes an input image or a list input images and exposes the resized result under configurable
- input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
- "image", "control_image")
+ """Create a configurable step for resizing images to the target area while maintaining the aspect ratio.
Args:
input_name (str, optional): Name of the image field to read from the
pipeline state. Defaults to "image".
output_name (str, optional): Name of the resized image field to write
back to the pipeline state. Defaults to "resized_image".
- vae_image_output_name (str, optional): Name of the image field
- to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
- processes the input image(s) differently for the VL and the VAE.
"""
if not isinstance(input_name, str) or not isinstance(output_name, str):
raise ValueError(
f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
)
- self.condition_image_size = 384 * 384
self._image_input_name = input_name
self._resized_image_output_name = output_name
- self._vae_image_output_name = vae_image_output_name
super().__init__()
+ @property
+ def description(self) -> str:
+        return f"Image Resize step that resizes the {self._image_input_name} to the target area while maintaining the aspect ratio."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_resize_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
+ ),
+ InputParam(
+ name="resolution",
+ default=640,
+ type_hint=int,
+                description="The target resolution (side length) used to compute the resize area; must be 1024 or 640",
+ ),
+ ]
+
@property
def intermediate_outputs(self) -> List[OutputParam]:
- return super().intermediate_outputs + [
+ return [
OutputParam(
- name=self._vae_image_output_name,
- type_hint=List[PIL.Image.Image],
- description="The images to be processed which will be further used by the VAE encoder.",
+ name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
),
]
+ @staticmethod
+ def check_inputs(resolution: int):
+ if resolution not in [1024, 640]:
+ raise ValueError(f"Resolution must be 1024 or 640 but is {resolution}")
+
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
+ self.check_inputs(resolution=block_state.resolution)
+
images = getattr(block_state, self._image_input_name)
if not is_valid_image_imagelist(images):
raise ValueError(f"Images must be image or list of images but are {type(images)}")
- if (
- not isinstance(images, torch.Tensor)
- and isinstance(images, PIL.Image.Image)
- and not isinstance(images, list)
- ):
+ if is_valid_image(images):
images = [images]
- # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
- condition_images = []
- vae_images = []
- for img in images:
- image_width, image_height = img.size
- condition_width, condition_height, _ = calculate_dimensions(
- self.condition_image_size, image_width / image_height
+ image_width, image_height = images[0].size
+ target_area = block_state.resolution * block_state.resolution
+ calculated_width, calculated_height, _ = calculate_dimensions(target_area, image_width / image_height)
+
+ resized_images = [
+ components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
+ for image in images
+ ]
+
+ setattr(block_state, self._resized_image_output_name, resized_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
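+# Resize sketch for Layered (hypothetical numbers): with resolution = 640 the target area is
+# 640 * 640 = 409600 pixels, so a 1500x1000 input (aspect ratio 1.5) is resized to roughly
+# 784x523 before any rounding applied by calculate_dimensions; all images in the list reuse
+# the dimensions computed from the first image's aspect ratio.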
+
+class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
+ """Resize each image independently based on its own aspect ratio. For QwenImage Edit Plus."""
+
+ model_name = "qwenimage-edit-plus"
+
+ def __init__(
+ self,
+ input_name: str = "image",
+ output_name: str = "resized_image",
+ target_area: int = 1024 * 1024,
+ ):
+ """Create a step for resizing images to a target area.
+
+ Each image is resized independently based on its own aspect ratio. This is suitable for Edit Plus where
+ multiple reference images can have different dimensions.
+
+ Args:
+ input_name (str, optional): Name of the image field to read. Defaults to "image".
+ output_name (str, optional): Name of the resized image field to write. Defaults to "resized_image".
+ target_area (int, optional): Target area in pixels. Defaults to 1024*1024.
+ """
+ if not isinstance(input_name, str) or not isinstance(output_name, str):
+ raise ValueError(
+ f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
)
- condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
- vae_images.append(img)
+ self._image_input_name = input_name
+ self._resized_image_output_name = output_name
+ self._target_area = target_area
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return (
+            f"Image Resize step that resizes {self._image_input_name} to the target area {self._target_area}.\n"
+ "Each image is resized independently based on its own aspect ratio."
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_resize_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name=self._image_input_name,
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image(s) to resize",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ images = getattr(block_state, self._image_input_name)
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if is_valid_image(images):
+ images = [images]
+
+ # Resize each image independently based on its own aspect ratio
+ resized_images = []
+ for image in images:
+ image_width, image_height = image.size
+ calculated_width, calculated_height, _ = calculate_dimensions(
+ self._target_area, image_width / image_height
+ )
+ resized_images.append(
+ components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
+ )
+
+ setattr(block_state, self._resized_image_output_name, resized_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
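+# Resize sketch for Edit Plus (hypothetical numbers): with the default target_area of
+# 1024 * 1024, a 1500x1000 reference (aspect ratio 1.5) maps to roughly 1254x836 and a
+# 512x512 reference stays square at roughly 1024x1024, each computed independently by
+# calculate_dimensions (subject to its internal rounding).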
+
+# ====================
+# 2. GET IMAGE PROMPT
+# ====================
+class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks):
+ """
+ Auto-caption step that generates a text prompt from the input image if none is provided. Uses the VL model to
+ generate a description of the image.
+ """
+
+ model_name = "qwenimage-layered"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Auto-caption step that generates a text prompt from the input image if none is provided.\n"
+ "Uses the VL model (text_encoder) to generate a description of the image.\n"
+            "If a prompt is already provided, this step passes it through unchanged."
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+ ComponentSpec("processor", Qwen2VLProcessor),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(name="image_caption_prompt_en", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_EN),
+ ConfigSpec(name="image_caption_prompt_cn", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_CN),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", type_hint=str, description="The prompt to encode"),
+ InputParam(
+ name="resized_image",
+ required=True,
+ type_hint=PIL.Image.Image,
+                description="The image to generate a caption from; it should be resized using the resize step",
+ ),
+ InputParam(
+ name="use_en_prompt",
+ default=False,
+ type_hint=bool,
+ description="Whether to use English prompt template",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ # If prompt is empty or None, generate caption from image
+ if block_state.prompt is None or block_state.prompt == "" or block_state.prompt == " ":
+ if block_state.use_en_prompt:
+ caption_prompt = components.config.image_caption_prompt_en
+ else:
+ caption_prompt = components.config.image_caption_prompt_cn
+
+ model_inputs = components.processor(
+ text=caption_prompt,
+ images=block_state.resized_image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ generated_ids = components.text_encoder.generate(**model_inputs, max_new_tokens=512)
+ generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
+ ]
+ output_text = components.processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )[0]
+
+ block_state.prompt = output_text.strip()
- setattr(block_state, self._resized_image_output_name, condition_images)
- setattr(block_state, self._vae_image_output_name, vae_images)
self.set_block_state(state, block_state)
return components, state
+# ====================
+# 3. TEXT ENCODER
+# ====================
class QwenImageTextEncoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
- return "Text Encoder step that generate text_embeddings to guide the image generation"
+ return "Text Encoder step that generates text embeddings to guide the image generation."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -421,11 +639,8 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
- ConfigSpec(
- name="prompt_template_encode",
- default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
- ),
- ConfigSpec(name="prompt_template_encode_start_idx", default=34),
+ ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_PROMPT_TEMPLATE),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_PROMPT_TEMPLATE_START_IDX),
ConfigSpec(name="tokenizer_max_length", default=1024),
]
@@ -532,7 +747,7 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
@property
def description(self) -> str:
- return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation"
+ return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -550,11 +765,8 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
- ConfigSpec(
- name="prompt_template_encode",
- default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
- ),
- ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+ ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX),
]
@property
@@ -565,7 +777,7 @@ def inputs(self) -> List[InputParam]:
InputParam(
name="resized_image",
required=True,
- type_hint=torch.Tensor,
+ type_hint=PIL.Image.Image,
description="The image prompt to encode, should be resized using resize step",
),
]
@@ -647,23 +859,93 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
return components, state
-class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
- model_name = "qwenimage"
+class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
+ """Text encoder for QwenImage Edit Plus (VL encoding with multiple images)."""
+
+ model_name = "qwenimage-edit-plus"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together "
+ "to generate text embeddings for guiding image generation."
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+ ComponentSpec("processor", Qwen2VLProcessor),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
- ConfigSpec(
- name="prompt_template_encode",
- default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE),
+ ConfigSpec(name="img_template_encode", default=QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+ InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+ InputParam(
+ name="resized_cond_image",
+ required=True,
+ type_hint=torch.Tensor,
+                description="The image(s) to encode; can be a single image or a list of images, resized to the 384x384 target area using the resize step",
),
- ConfigSpec(
- name="img_template_encode",
- default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The prompt embeddings",
+ ),
+ OutputParam(
+ name="prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The encoder attention mask",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings mask",
),
- ConfigSpec(name="prompt_template_encode_start_idx", default=64),
]
+ @staticmethod
+ def check_inputs(prompt, negative_prompt):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ negative_prompt is not None
+ and not isinstance(negative_prompt, str)
+ and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
@@ -676,7 +958,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
components.text_encoder,
components.processor,
prompt=block_state.prompt,
- image=block_state.resized_image,
+ image=block_state.resized_cond_image,
prompt_template_encode=components.config.prompt_template_encode,
img_template_encode=components.config.img_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
@@ -692,7 +974,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
components.text_encoder,
components.processor,
prompt=negative_prompt,
- image=block_state.resized_image,
+ image=block_state.resized_cond_image,
prompt_template_encode=components.config.prompt_template_encode,
img_template_encode=components.config.img_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
@@ -704,12 +986,15 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
return components, state
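+
+# Prompt layout sketch for Edit Plus (illustrative; the templates live in prompt_templates.py):
+# with two reference images, the text handed to the VL encoder is roughly
+#   "Picture 1: <|vision_start|><|image_pad|><|vision_end|>Picture 2: <|vision_start|><|image_pad|><|vision_end|>" + prompt
+# wrapped in QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, and the first
+# QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX tokens are dropped from the returned embeddings.
+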
+# ====================
+# 4. IMAGE PREPROCESS
+# ====================
class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
- return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep."
+        return "Image Preprocess step for the inpainting task. This processes the image and mask inputs together. Images will be resized to the given height and width."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -726,8 +1011,7 @@ def expected_components(self) -> List[ComponentSpec]:
def inputs(self) -> List[InputParam]:
return [
InputParam("mask_image", required=True),
- InputParam("resized_image"),
- InputParam("image"),
+ InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("padding_mask_crop"),
@@ -757,23 +1041,73 @@ def check_inputs(height, width, vae_scale_factor):
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
- if block_state.resized_image is None and block_state.image is None:
- raise ValueError("resized_image and image cannot be None at the same time")
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
- if block_state.resized_image is None:
- image = block_state.image
- self.check_inputs(
- height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
+ components.image_mask_processor.preprocess(
+ image=block_state.image,
+ mask=block_state.mask_image,
+ height=height,
+ width=width,
+ padding_mask_crop=block_state.padding_mask_crop,
)
- height = block_state.height or components.default_height
- width = block_state.width or components.default_width
- else:
- width, height = block_state.resized_image[0].size
- image = block_state.resized_image
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage-edit"
+
+ @property
+ def description(self) -> str:
+        return "Image Preprocess step for the inpainting task. This processes the image and mask inputs together. Images should be resized first."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_mask_processor",
+ InpaintProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("mask_image", required=True),
+ InputParam("resized_image", required=True),
+ InputParam("padding_mask_crop"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="processed_image"),
+ OutputParam(name="processed_mask_image"),
+ OutputParam(
+ name="mask_overlay_kwargs",
+ type_hint=Dict,
+ description="The kwargs for the postprocess step to apply the mask overlay",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ width, height = block_state.resized_image[0].size
block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
components.image_mask_processor.preprocess(
- image=image,
+ image=block_state.resized_image,
mask=block_state.mask_image,
height=height,
width=width,
@@ -790,7 +1124,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
@property
def description(self) -> str:
- return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
+        return "Image Preprocess step that resizes the image to the given height and width."
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -805,7 +1139,11 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
- return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+ return [
+ InputParam("image", required=True),
+ InputParam("height"),
+ InputParam("width"),
+ ]
@property
def intermediate_outputs(self) -> List[OutputParam]:
@@ -823,22 +1161,58 @@ def check_inputs(height, width, vae_scale_factor):
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
- if block_state.resized_image is None and block_state.image is None:
- raise ValueError("resized_image and image cannot be None at the same time")
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
- if block_state.resized_image is None:
- image = block_state.image
- self.check_inputs(
- height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
- )
- height = block_state.height or components.default_height
- width = block_state.width or components.default_width
- else:
- width, height = block_state.resized_image[0].size
- image = block_state.resized_image
+ block_state.processed_image = components.image_processor.preprocess(
+ image=block_state.image,
+ height=height,
+ width=width,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage-edit"
+
+ @property
+ def description(self) -> str:
+        return "Image Preprocess step. Images need to be resized first."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("resized_image", required=True),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ width, height = block_state.resized_image[0].size
block_state.processed_image = components.image_processor.preprocess(
- image=image,
+ image=block_state.resized_image,
height=height,
width=width,
)
@@ -847,59 +1221,64 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
return components, state
-class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
+class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
model_name = "qwenimage-edit-plus"
- def __init__(self):
- self.vae_image_size = 1024 * 1024
- super().__init__()
-
@property
def description(self) -> str:
- return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing."
+ return "Image Preprocess step. Images can be resized first using QwenImageEditResizeStep."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
@property
def inputs(self) -> List[InputParam]:
- return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+ return [InputParam("resized_image")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
- if block_state.vae_image is None and block_state.image is None:
- raise ValueError("`vae_image` and `image` cannot be None at the same time")
+ image = block_state.resized_image
- vae_image_sizes = None
- if block_state.vae_image is None:
- image = block_state.image
- self.check_inputs(
- height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
- )
- height = block_state.height or components.default_height
- width = block_state.width or components.default_width
- block_state.processed_image = components.image_processor.preprocess(
- image=image, height=height, width=width
+ is_image_list = isinstance(image, list)
+ if not is_image_list:
+ image = [image]
+
+ processed_images = []
+ for img in image:
+ img_width, img_height = img.size
+ processed_images.append(
+ components.image_processor.preprocess(image=img, height=img_height, width=img_width)
)
- else:
- # QwenImage Edit Plus can allow multiple input images with varied resolutions
- processed_images = []
- vae_image_sizes = []
- for img in block_state.vae_image:
- width, height = img.size
- vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
- vae_image_sizes.append((vae_width, vae_height))
- processed_images.append(
- components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
- )
+ if is_image_list:
block_state.processed_image = processed_images
-
- block_state.vae_image_sizes = vae_image_sizes
+ else:
+ block_state.processed_image = processed_images[0]
self.set_block_state(state, block_state)
return components, state
-class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
+# ====================
+# 5. VAE ENCODER
+# ====================
+class QwenImageVaeEncoderStep(ModularPipelineBlocks):
+ """VAE encoder that handles both single images and lists of images with varied resolutions."""
+
model_name = "qwenimage"
def __init__(
@@ -909,21 +1288,12 @@ def __init__(
):
"""Initialize a VAE encoder step for converting images to latent representations.
- Both the input and output names are configurable so this block can be configured to process to different image
- inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+ Handles both single images and lists of images. When input is a list, outputs a list of latents. When input is
+ a single tensor, outputs a single latent tensor.
Args:
- input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
- Examples: "processed_image" or "processed_control_image"
- output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
- Examples: "image_latents" or "control_image_latents"
-
- Examples:
- # Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep()
-
- # Custom input/output names for control image QwenImageVaeEncoderDynamicStep(
- input_name="processed_control_image", output_name="control_image_latents"
- )
+ input_name (str, optional): Name of the input image tensor or list. Defaults to "processed_image".
+ output_name (str, optional): Name of the output latent tensor or list. Defaults to "image_latents".
"""
self._image_input_name = input_name
self._image_latents_output_name = output_name
@@ -931,17 +1301,18 @@ def __init__(
@property
def description(self) -> str:
- return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+ return (
+ f"VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+ "Handles both single images and lists of images with varied resolutions."
+ )
@property
def expected_components(self) -> List[ComponentSpec]:
- components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
- return components
+ return [ComponentSpec("vae", AutoencoderKLQwenImage)]
@property
def inputs(self) -> List[InputParam]:
- inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
- return inputs
+ return [InputParam(self._image_input_name, required=True), InputParam("generator")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
@@ -949,46 +1320,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
OutputParam(
self._image_latents_output_name,
type_hint=torch.Tensor,
- description="The latents representing the reference image",
- )
- ]
-
- @torch.no_grad()
- def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
- block_state = self.get_block_state(state)
-
- device = components._execution_device
- dtype = components.vae.dtype
-
- image = getattr(block_state, self._image_input_name)
-
- # Encode image into latents
- image_latents = encode_vae_image(
- image=image,
- vae=components.vae,
- generator=block_state.generator,
- device=device,
- dtype=dtype,
- latent_channels=components.num_channels_latents,
- )
- setattr(block_state, self._image_latents_output_name, image_latents)
-
- self.set_block_state(state, block_state)
-
- return components, state
-
-
-class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
- model_name = "qwenimage-edit-plus"
-
- @property
- def intermediate_outputs(self) -> List[OutputParam]:
- # Each reference image latent can have varied resolutions hence we return this as a list.
- return [
- OutputParam(
- self._image_latents_output_name,
- type_hint=List[torch.Tensor],
- description="The latents representing the reference image(s).",
+ description="The latents representing the reference image(s). Single tensor or list depending on input.",
)
]
@@ -1000,8 +1332,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
dtype = components.vae.dtype
image = getattr(block_state, self._image_input_name)
+ is_image_list = isinstance(image, list)
+ if not is_image_list:
+ image = [image]
- # Encode image into latents
+ # Handle both single image and list of images
image_latents = []
for img in image:
image_latents.append(
@@ -1014,6 +1349,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
latent_channels=components.num_channels_latents,
)
)
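+        # Preserve the input contract: a list of images yields a list of latents
+        # (each keeping its own spatial size); a single image is unwrapped back
+        # to a single latent tensor.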
+ if not is_image_list:
+ image_latents = image_latents[0]
setattr(block_state, self._image_latents_output_name, image_latents)
@@ -1131,3 +1468,37 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
self.set_block_state(state, block_state)
return components, state
+
+
+# ====================
+# 6. PERMUTE LATENTS
+# ====================
+class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks):
+ """Permute image latents from VAE format to Layered format."""
+
+ model_name = "qwenimage-layered"
+
+ def __init__(self, input_name: str = "image_latents"):
+ self._input_name = input_name
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Permute {self._input_name} from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(self._input_name, required=True),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Permute: (B, C, 1, H, W) -> (B, 1, C, H, W)
+ latents = getattr(block_state, self._input_name)
+ setattr(block_state, self._input_name, latents.permute(0, 2, 1, 3, 4))
+
+ self.set_block_state(state, block_state)
+ return components, state
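+
+
+# Hedged, self-contained sketch (not called by the pipeline) of the permutation
+# above: VAE-style latents (B, C, 1, H, W) become layered-style latents
+# (B, 1, C, H, W), whose dims 3 and 4 are the H/W the layered input step reads.
+# Tensor sizes below are arbitrary.
+def _demo_layered_permute():
+    import torch
+
+    vae_latents = torch.randn(2, 16, 1, 32, 32)  # (B, C, 1, H, W)
+    layered = vae_latents.permute(0, 2, 1, 3, 4)  # (B, 1, C, H, W)
+    assert layered.shape == (2, 1, 16, 32, 32)
+    assert (layered.shape[3], layered.shape[4]) == (32, 32)  # H, W
+    return layered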
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
index 6e656e484847..4a1cf3700c57 100644
--- a/src/diffusers/modular_pipelines/qwenimage/inputs.py
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -19,7 +19,7 @@
from ...models import QwenImageMultiControlNetModel
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
-from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
def repeat_tensor_to_batch_size(
@@ -221,37 +221,16 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
-class QwenImageInputsDynamicStep(ModularPipelineBlocks):
- model_name = "qwenimage"
-
- def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
- """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
-
- This step handles multiple common tasks to prepare inputs for the denoising step:
- 1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
- 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
-
- This is a dynamic block that allows you to configure which inputs to process.
+class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
+ """Input step for QwenImage: update height/width, expand batch, patchify."""
- Args:
- image_latent_inputs (List[str], optional): Names of image latent tensors to process.
- These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
- list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
- additional_batch_inputs (List[str], optional):
- Names of additional conditional input tensors to expand batch size. These tensors will only have their
- batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
- Defaults to []. Examples: ["processed_mask_image"]
-
- Examples:
- # Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
-
- # Configure to process multiple image latent inputs
- QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
+ model_name = "qwenimage"
- # Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
- image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
- )
- """
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
@@ -263,14 +242,12 @@ def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additiona
@property
def description(self) -> str:
- # Functionality section
summary_section = (
"Input processing step that:\n"
- " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
- # Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
@@ -279,11 +256,16 @@ def description(self) -> str:
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
- # Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
@property
def inputs(self) -> List[InputParam]:
inputs = [
@@ -293,11 +275,9 @@ def inputs(self) -> List[InputParam]:
InputParam(name="width"),
]
- # Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
- # Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
@@ -306,26 +286,28 @@ def inputs(self) -> List[InputParam]:
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
- OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
- OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
- ]
-
- @property
- def expected_components(self) -> List[ComponentSpec]:
- return [
- ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ OutputParam(
+ name="image_height",
+ type_hint=int,
+ description="The image height calculated from the image latents dimension",
+ ),
+ OutputParam(
+ name="image_width",
+ type_hint=int,
+ description="The image width calculated from the image latents dimension",
+ ),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
- # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ # Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
- # 1. Calculate height/width from latents
+ # 1. Calculate height/width from latents and update if not provided
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
@@ -335,7 +317,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
if not hasattr(block_state, "image_width"):
block_state.image_width = width
- # 2. Patchify the image latent tensor
+ # 2. Patchify
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
# 3. Expand batch size
@@ -354,7 +336,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
if input_tensor is None:
continue
- # Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
@@ -368,63 +349,270 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state
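+
+
+# Hedged sketch (not used by the pipeline) of the three-step flow in
+# QwenImageAdditionalInputsStep above: derive image height/width from the latent
+# spatial dims, patchify into a token sequence, then expand to the final batch
+# size. `_toy_pack` just flattens spatial positions and does NOT reproduce the
+# real QwenImagePachifier packing; vae_scale_factor=8 is assumed for illustration.
+def _demo_additional_inputs_flow(batch_size: int = 2, num_images_per_prompt: int = 2):
+    import torch
+
+    vae_scale_factor = 8
+    latents = torch.randn(1, 16, 64, 64)  # toy (B, C, H', W') latents
+
+    # 1. Image height/width from the latent spatial dims.
+    height = latents.shape[-2] * vae_scale_factor  # 512
+    width = latents.shape[-1] * vae_scale_factor  # 512
+
+    # 2. "Patchify": flatten spatial positions into a sequence (toy version).
+    packed = latents.flatten(2).transpose(1, 2)  # (B, H'*W', C)
+
+    # 3. Expand batch to batch_size * num_images_per_prompt.
+    packed = packed.repeat(batch_size * num_images_per_prompt, 1, 1)
+    return height, width, packed  # packed: (4, 4096, 16) with the defaults
+
+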
-class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
+class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
+ """Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
+
model_name = "qwenimage-edit-plus"
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ summary_section = (
+ "Input processing step for Edit Plus that:\n"
+ " 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size\n"
+ " Height/width defaults to last image in the list."
+ )
+
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
- OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
- OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
+ OutputParam(
+ name="image_height",
+ type_hint=List[int],
+ description="The image heights calculated from the image latents dimension",
+ ),
+ OutputParam(
+ name="image_width",
+ type_hint=List[int],
+ description="The image widths calculated from the image latents dimension",
+ ),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
- # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ # Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
- # Each image latent can have different size in QwenImage Edit Plus.
+ is_list = isinstance(image_latent_tensor, list)
+ if not is_list:
+ image_latent_tensor = [image_latent_tensor]
+
image_heights = []
image_widths = []
packed_image_latent_tensors = []
- for img_latent_tensor in image_latent_tensor:
+ for i, img_latent_tensor in enumerate(image_latent_tensor):
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
image_heights.append(height)
image_widths.append(width)
- # 2. Patchify the image latent tensor
+ # 2. Patchify
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
# 3. Expand batch size
img_latent_tensor = repeat_tensor_to_batch_size(
- input_name=image_latent_input_name,
+ input_name=f"{image_latent_input_name}[{i}]",
input_tensor=img_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
packed_image_latent_tensors.append(img_latent_tensor)
+ # Concatenate all packed latents along dim=1
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
+
+ # Output lists of heights/widths
block_state.image_height = image_heights
block_state.image_width = image_widths
- setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
+ # Default height/width from last image
block_state.height = block_state.height or image_heights[-1]
block_state.width = block_state.width or image_widths[-1]
+ setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
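+# Hedged sketch (not used by the pipeline) of how the Edit Plus input step above
+# combines reference latents of different resolutions: each latent is packed to
+# (B, seq_i, C) and the per-image token sequences are concatenated along dim=1.
+# The token counts below are arbitrary stand-ins.
+def _demo_edit_plus_concat():
+    import torch
+
+    packed_a = torch.randn(2, 1024, 64)  # reference image 1 -> 1024 tokens
+    packed_b = torch.randn(2, 4096, 64)  # reference image 2 -> 4096 tokens
+    combined = torch.cat([packed_a, packed_b], dim=1)
+    assert combined.shape == (2, 5120, 64)
+    return combined
+
+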
+# YiYi TODO: support defining the config default component at the ModularPipeline level.
+# This is the same as QwenImageAdditionalInputsStep, but with the layered pachifier.
+class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
+ """Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
+
+ model_name = "qwenimage-layered"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ summary_section = (
+ "Input processing step for Layered that:\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ ]
+
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="image_height",
+ type_hint=int,
+ description="The image height calculated from the image latents dimension",
+ ),
+ OutputParam(
+ name="image_width",
+ type_hint=int,
+ description="The image width calculated from the image latents dimension",
+ ),
+ OutputParam(name="height", type_hint=int, description="The height of the image output"),
+ OutputParam(name="width", type_hint=int, description="The width of the image output"),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+            # 1. Calculate height/width from the latent spatial dims and set them on the state
+ # Layered latents are (B, layers, C, H, W)
+ height = image_latent_tensor.shape[3] * components.vae_scale_factor
+ width = image_latent_tensor.shape[4] * components.vae_scale_factor
+ block_state.height = height
+ block_state.width = width
+
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify with layered pachifier
+ image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
- # Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
deleted file mode 100644
index dcce0cab5dd1..000000000000
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
+++ /dev/null
@@ -1,1113 +0,0 @@
-# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ...utils import logging
-from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
-from .before_denoise import (
- QwenImageControlNetBeforeDenoiserStep,
- QwenImageCreateMaskLatentsStep,
- QwenImageEditPlusRoPEInputsStep,
- QwenImageEditRoPEInputsStep,
- QwenImagePrepareLatentsStep,
- QwenImagePrepareLatentsWithStrengthStep,
- QwenImageRoPEInputsStep,
- QwenImageSetTimestepsStep,
- QwenImageSetTimestepsWithStrengthStep,
-)
-from .decoders import (
- QwenImageAfterDenoiseStep,
- QwenImageDecoderStep,
- QwenImageInpaintProcessImagesOutputStep,
- QwenImageProcessImagesOutputStep,
-)
-from .denoise import (
- QwenImageControlNetDenoiseStep,
- QwenImageDenoiseStep,
- QwenImageEditDenoiseStep,
- QwenImageEditInpaintDenoiseStep,
- QwenImageInpaintControlNetDenoiseStep,
- QwenImageInpaintDenoiseStep,
- QwenImageLoopBeforeDenoiserControlNet,
-)
-from .encoders import (
- QwenImageControlNetVaeEncoderStep,
- QwenImageEditPlusProcessImagesInputStep,
- QwenImageEditPlusResizeDynamicStep,
- QwenImageEditPlusTextEncoderStep,
- QwenImageEditPlusVaeEncoderDynamicStep,
- QwenImageEditResizeDynamicStep,
- QwenImageEditTextEncoderStep,
- QwenImageInpaintProcessImagesInputStep,
- QwenImageProcessImagesInputStep,
- QwenImageTextEncoderStep,
- QwenImageVaeEncoderDynamicStep,
-)
-from .inputs import (
- QwenImageControlNetInputsStep,
- QwenImageEditPlusInputsDynamicStep,
- QwenImageInputsDynamicStep,
- QwenImageTextInputsStep,
-)
-
-
-logger = logging.get_logger(__name__)
-
-# 1. QwenImage
-
-## 1.1 QwenImage/text2image
-
-#### QwenImage/decode
-#### (standard decode step works for most tasks except for inpaint)
-QwenImageDecodeBlocks = InsertableDict(
- [
- ("decode", QwenImageDecoderStep()),
- ("postprocess", QwenImageProcessImagesOutputStep()),
- ]
-)
-
-
-class QwenImageDecodeStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageDecodeBlocks.values()
- block_names = QwenImageDecodeBlocks.keys()
-
- @property
- def description(self):
- return "Decode step that decodes the latents to images and postprocess the generated image."
-
-
-#### QwenImage/text2image presets
-TEXT2IMAGE_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageTextEncoderStep()),
- ("input", QwenImageTextInputsStep()),
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsStep()),
- ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
- ("denoise", QwenImageDenoiseStep()),
- ("after_denoise", QwenImageAfterDenoiseStep()),
- ("decode", QwenImageDecodeStep()),
- ]
-)
-
-
-## 1.2 QwenImage/inpaint
-
-#### QwenImage/inpaint vae encoder
-QwenImageInpaintVaeEncoderBlocks = InsertableDict(
- [
- (
- "preprocess",
- QwenImageInpaintProcessImagesInputStep,
- ), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
- ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
- ]
-)
-
-
-class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageInpaintVaeEncoderBlocks.values()
- block_names = QwenImageInpaintVaeEncoderBlocks.keys()
-
- @property
- def description(self) -> str:
- return (
- "This step is used for processing image and mask inputs for inpainting tasks. It:\n"
- " - Resizes the image to the target size, based on `height` and `width`.\n"
- " - Processes and updates `image` and `mask_image`.\n"
- " - Creates `image_latents`."
- )
-
-
-#### QwenImage/inpaint inputs
-QwenImageInpaintInputBlocks = InsertableDict(
- [
- ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
- (
- "additional_inputs",
- QwenImageInputsDynamicStep(
- image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
- ),
- ),
- ]
-)
-
-
-class QwenImageInpaintInputStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageInpaintInputBlocks.values()
- block_names = QwenImageInpaintInputBlocks.keys()
-
- @property
- def description(self):
- return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
- " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
- " - update height/width based `image_latents`, patchify `image_latents`."
-
-
-# QwenImage/inpaint prepare latents
-QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
- [
- ("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
- ("create_mask_latents", QwenImageCreateMaskLatentsStep()),
- ]
-)
-
-
-class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
- block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
-
- @property
- def description(self) -> str:
- return (
- "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
- " - Add noise to the image latents to create the latents input for the denoiser.\n"
- " - Create the pachified latents `mask` based on the processedmask image.\n"
- )
-
-
-#### QwenImage/inpaint decode
-QwenImageInpaintDecodeBlocks = InsertableDict(
- [
- ("decode", QwenImageDecoderStep()),
- ("postprocess", QwenImageInpaintProcessImagesOutputStep()),
- ]
-)
-
-
-class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageInpaintDecodeBlocks.values()
- block_names = QwenImageInpaintDecodeBlocks.keys()
-
- @property
- def description(self):
- return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
-
-
-#### QwenImage/inpaint presets
-INPAINT_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageTextEncoderStep()),
- ("vae_encoder", QwenImageInpaintVaeEncoderStep()),
- ("input", QwenImageInpaintInputStep()),
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
- ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
- ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
- ("denoise", QwenImageInpaintDenoiseStep()),
- ("after_denoise", QwenImageAfterDenoiseStep()),
- ("decode", QwenImageInpaintDecodeStep()),
- ]
-)
-
-
-## 1.3 QwenImage/img2img
-
-#### QwenImage/img2img vae encoder
-QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
- [
- ("preprocess", QwenImageProcessImagesInputStep()),
- ("encode", QwenImageVaeEncoderDynamicStep()),
- ]
-)
-
-
-class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
-
- block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
- block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
-
- @property
- def description(self) -> str:
- return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
-
-
-#### QwenImage/img2img inputs
-QwenImageImg2ImgInputBlocks = InsertableDict(
- [
- ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
- ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
- ]
-)
-
-
-class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageImg2ImgInputBlocks.values()
- block_names = QwenImageImg2ImgInputBlocks.keys()
-
- @property
- def description(self):
- return "Input step that prepares the inputs for the img2img denoising step. It:\n"
- " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
- " - update height/width based `image_latents`, patchify `image_latents`."
-
-
-#### QwenImage/img2img presets
-IMAGE2IMAGE_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageTextEncoderStep()),
- ("vae_encoder", QwenImageImg2ImgVaeEncoderStep()),
- ("input", QwenImageImg2ImgInputStep()),
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
- ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
- ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
- ("denoise", QwenImageDenoiseStep()),
- ("after_denoise", QwenImageAfterDenoiseStep()),
- ("decode", QwenImageDecodeStep()),
- ]
-)
-
-
-## 1.4 QwenImage/controlnet
-
-#### QwenImage/controlnet presets
-CONTROLNET_BLOCKS = InsertableDict(
- [
- ("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image
- ("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet
- (
- "controlnet_before_denoise",
- QwenImageControlNetBeforeDenoiserStep(),
- ), # before denoise step (after set_timesteps step)
- (
- "controlnet_denoise_loop_before",
- QwenImageLoopBeforeDenoiserControlNet(),
- ), # controlnet loop step (insert before the denoiseloop_denoiser)
- ]
-)
-
-
-## 1.5 QwenImage/auto encoders
-
-
-#### for inpaint and img2img tasks
-class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
- block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
- block_names = ["inpaint", "img2img"]
- block_trigger_inputs = ["mask_image", "image"]
-
- @property
- def description(self):
- return (
- "Vae encoder step that encode the image inputs into their latent representations.\n"
- + "This is an auto pipeline block.\n"
- + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
- + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
- + " - if `mask_image` or `image` is not provided, step will be skipped."
- )
-
-
-# for controlnet tasks
-class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
- block_classes = [QwenImageControlNetVaeEncoderStep]
- block_names = ["controlnet"]
- block_trigger_inputs = ["control_image"]
-
- @property
- def description(self):
- return (
- "Vae encoder step that encode the image inputs into their latent representations.\n"
- + "This is an auto pipeline block.\n"
- + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
- + " - if `control_image` is not provided, step will be skipped."
- )
-
-
-## 1.6 QwenImage/auto inputs
-
-
-# text2image/inpaint/img2img
-class QwenImageAutoInputStep(AutoPipelineBlocks):
- block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep]
- block_names = ["inpaint", "img2img", "text2image"]
- block_trigger_inputs = ["processed_mask_image", "image_latents", None]
-
- @property
- def description(self):
- return (
- "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
- " This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n"
- + " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n"
- + " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
- + " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
- )
-
-
-# controlnet
-class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks):
- block_classes = [QwenImageControlNetInputsStep]
- block_names = ["controlnet"]
- block_trigger_inputs = ["control_image_latents"]
-
- @property
- def description(self):
- return (
- "Controlnet input step that prepare the control_image_latents input.\n"
- + "This is an auto pipeline block.\n"
- + " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n"
- + " - if `control_image_latents` is not provided, step will be skipped."
- )
-
-
-## 1.7 QwenImage/auto before denoise step
-# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step
-
-# QwenImage/text2image before denoise
-QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict(
- [
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsStep()),
- ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
- ]
-)
-
-
-class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values()
- block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys()
-
- @property
- def description(self):
- return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task."
-
-
-# QwenImage/inpaint before denoise
-QwenImageInpaintBeforeDenoiseBlocks = InsertableDict(
- [
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
- ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
- ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
- ]
-)
-
-
-class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageInpaintBeforeDenoiseBlocks.values()
- block_names = QwenImageInpaintBeforeDenoiseBlocks.keys()
-
- @property
- def description(self):
- return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
-
-
-# QwenImage/img2img before denoise
-QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict(
- [
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
- ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
- ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
- ]
-)
-
-
-class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values()
- block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys()
-
- @property
- def description(self):
- return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
-
-
-# auto before_denoise step for text2image, inpaint, img2img tasks
-class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks):
- block_classes = [
- QwenImageInpaintBeforeDenoiseStep,
- QwenImageImg2ImgBeforeDenoiseStep,
- QwenImageText2ImageBeforeDenoiseStep,
- ]
- block_names = ["inpaint", "img2img", "text2image"]
- block_trigger_inputs = ["processed_mask_image", "image_latents", None]
-
- @property
- def description(self):
- return (
- "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
- + "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n"
- + " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
- + " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
- + " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
- )
-
-
-# auto before_denoise step for controlnet tasks
-class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks):
- block_classes = [QwenImageControlNetBeforeDenoiserStep]
- block_names = ["controlnet"]
- block_trigger_inputs = ["control_image_latents"]
-
- @property
- def description(self):
- return (
- "Controlnet before denoise step that prepare the controlnet input.\n"
- + "This is an auto pipeline block.\n"
- + " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n"
- + " - if `control_image_latents` is not provided, step will be skipped."
- )
-
-
-## 1.8 QwenImage/auto denoise
-
-
-# auto denoise step for controlnet tasks: works for all tasks with controlnet
-class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks):
- block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep]
- block_names = ["inpaint_denoise", "denoise"]
- block_trigger_inputs = ["mask", None]
-
- @property
- def description(self):
- return (
- "Controlnet step during the denoising process. \n"
- " This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n"
- + " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n"
- + " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n"
- )
-
-
-# auto denoise step for everything: works for all tasks with or without controlnet
-class QwenImageAutoDenoiseStep(AutoPipelineBlocks):
- block_classes = [
- QwenImageControlNetAutoDenoiseStep,
- QwenImageInpaintDenoiseStep,
- QwenImageDenoiseStep,
- ]
- block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
- block_trigger_inputs = ["control_image_latents", "mask", None]
-
- @property
- def description(self):
- return (
- "Denoise step that iteratively denoise the latents. \n"
- " This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n"
- + " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n"
- + " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n"
- + " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n"
- )
-
-
-## 1.9 QwenImage/auto decode
-# auto decode step for inpaint and text2image tasks
-
-
-class QwenImageAutoDecodeStep(AutoPipelineBlocks):
- block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
- block_names = ["inpaint_decode", "decode"]
- block_trigger_inputs = ["mask", None]
-
- @property
- def description(self):
- return (
- "Decode step that decode the latents into images. \n"
- " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
- + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
- + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
- )
-
-
-class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = [
- QwenImageAutoInputStep,
- QwenImageOptionalControlNetInputStep,
- QwenImageAutoBeforeDenoiseStep,
- QwenImageOptionalControlNetBeforeDenoiseStep,
- QwenImageAutoDenoiseStep,
- QwenImageAfterDenoiseStep,
- ]
- block_names = [
- "input",
- "controlnet_input",
- "before_denoise",
- "controlnet_before_denoise",
- "denoise",
- "after_denoise",
- ]
-
- @property
- def description(self):
- return (
- "Core step that performs the denoising process. \n"
- + " - `QwenImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
- + " - `QwenImageOptionalControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
- + " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
- + " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n"
- + " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
- + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
- + " - for image-to-image generation, you need to provide `image_latents`\n"
- + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
- + " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
- + " - for text-to-image generation, all you need to provide is prompt embeddings"
- )
-
-
-## 1.10 QwenImage/auto block & presets
-AUTO_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageTextEncoderStep()),
- ("vae_encoder", QwenImageAutoVaeEncoderStep()),
- ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
- ("denoise", QwenImageCoreDenoiseStep()),
- ("decode", QwenImageAutoDecodeStep()),
- ]
-)
-
-
-class QwenImageAutoBlocks(SequentialPipelineBlocks):
- model_name = "qwenimage"
-
- block_classes = AUTO_BLOCKS.values()
- block_names = AUTO_BLOCKS.keys()
-
- @property
- def description(self):
- return (
- "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
- + "- for image-to-image generation, you need to provide `image`\n"
- + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
- + "- to run the controlnet workflow, you need to provide `control_image`\n"
- + "- for text-to-image generation, all you need to provide is `prompt`"
- )
-
-
-# 2. QwenImage-Edit
-
-## 2.1 QwenImage-Edit/edit
-
-#### QwenImage-Edit/edit vl encoder: take both image and text prompts
-QwenImageEditVLEncoderBlocks = InsertableDict(
- [
- ("resize", QwenImageEditResizeDynamicStep()),
- ("encode", QwenImageEditTextEncoderStep()),
- ]
-)
-
-
-class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageEditVLEncoderBlocks.values()
- block_names = QwenImageEditVLEncoderBlocks.keys()
-
- @property
- def description(self) -> str:
- return "QwenImage-Edit VL encoder step that encode the image an text prompts together."
-
-
-#### QwenImage-Edit/edit vae encoder
-QwenImageEditVaeEncoderBlocks = InsertableDict(
- [
- ("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step
- ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image
- ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
- ]
-)
-
-
-class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageEditVaeEncoderBlocks.values()
- block_names = QwenImageEditVaeEncoderBlocks.keys()
-
- @property
- def description(self) -> str:
- return "Vae encoder step that encode the image inputs into their latent representations."
-
-
-#### QwenImage-Edit/edit input
-QwenImageEditInputBlocks = InsertableDict(
- [
- ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
- ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
- ]
-)
-
-
-class QwenImageEditInputStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageEditInputBlocks.values()
- block_names = QwenImageEditInputBlocks.keys()
-
- @property
- def description(self):
- return "Input step that prepares the inputs for the edit denoising step. It:\n"
- " - make sure the text embeddings have consistent batch size as well as the additional inputs: \n"
- " - `image_latents`.\n"
- " - update height/width based `image_latents`, patchify `image_latents`."
-
-
-#### QwenImage/edit presets
-EDIT_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageEditVLEncoderStep()),
- ("vae_encoder", QwenImageEditVaeEncoderStep()),
- ("input", QwenImageEditInputStep()),
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsStep()),
- ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
- ("denoise", QwenImageEditDenoiseStep()),
- ("after_denoise", QwenImageAfterDenoiseStep()),
- ("decode", QwenImageDecodeStep()),
- ]
-)
-
-
-## 2.2 QwenImage-Edit/edit inpaint
-
-#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step
-QwenImageEditInpaintVaeEncoderBlocks = InsertableDict(
- [
- ("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image
- (
- "preprocess",
- QwenImageInpaintProcessImagesInputStep,
- ), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
- (
- "encode",
- QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
- ), # processed_image -> image_latents
- ]
-)
-
-
-class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageEditInpaintVaeEncoderBlocks.values()
- block_names = QwenImageEditInpaintVaeEncoderBlocks.keys()
-
- @property
- def description(self) -> str:
- return (
- "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
- " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
- " - process the resized image and mask image.\n"
- " - create image latents."
- )
-
-
-#### QwenImage-Edit/edit inpaint presets
-EDIT_INPAINT_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageEditVLEncoderStep()),
- ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
- ("input", QwenImageInpaintInputStep()),
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
- ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
- ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
- ("denoise", QwenImageEditInpaintDenoiseStep()),
- ("after_denoise", QwenImageAfterDenoiseStep()),
- ("decode", QwenImageInpaintDecodeStep()),
- ]
-)
-
-
-## 2.3 QwenImage-Edit/auto encoders
-
-
-class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
- block_classes = [
- QwenImageEditInpaintVaeEncoderStep,
- QwenImageEditVaeEncoderStep,
- ]
- block_names = ["edit_inpaint", "edit"]
- block_trigger_inputs = ["mask_image", "image"]
-
- @property
- def description(self):
- return (
- "Vae encoder step that encode the image inputs into their latent representations. \n"
- " This is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
- + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
- + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
- + " - if `mask_image` or `image` is not provided, step will be skipped."
- )
-
-
-## 2.4 QwenImage-Edit/auto inputs
-class QwenImageEditAutoInputStep(AutoPipelineBlocks):
- block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
- block_names = ["edit_inpaint", "edit"]
- block_trigger_inputs = ["processed_mask_image", "image_latents"]
-
- @property
- def description(self):
- return (
- "Input step that prepares the inputs for the edit denoising step.\n"
- + " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
- + " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
- + " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n"
- + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
- )
-
-
-## 2.5 QwenImage-Edit/auto before denoise
-# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step
-
-#### QwenImage-Edit/edit before denoise
-QwenImageEditBeforeDenoiseBlocks = InsertableDict(
- [
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsStep()),
- ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
- ]
-)
-
-
-class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageEditBeforeDenoiseBlocks.values()
- block_names = QwenImageEditBeforeDenoiseBlocks.keys()
-
- @property
- def description(self):
- return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
-
-
-#### QwenImage-Edit/edit inpaint before denoise
-QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict(
- [
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
- ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
- ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
- ]
-)
-
-
-class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values()
- block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys()
-
- @property
- def description(self):
- return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task."
-
-
-# auto before_denoise step for edit and edit_inpaint tasks
-class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks):
- model_name = "qwenimage-edit"
- block_classes = [
- QwenImageEditInpaintBeforeDenoiseStep,
- QwenImageEditBeforeDenoiseStep,
- ]
- block_names = ["edit_inpaint", "edit"]
- block_trigger_inputs = ["processed_mask_image", "image_latents"]
-
- @property
- def description(self):
- return (
- "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
- + "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n"
- + " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
- + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
- + " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped."
- )
-
-
-## 2.6 QwenImage-Edit/auto denoise
-
-
-class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
- model_name = "qwenimage-edit"
-
- block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep]
- block_names = ["inpaint_denoise", "denoise"]
- block_trigger_inputs = ["processed_mask_image", "image_latents"]
-
- @property
- def description(self):
- return (
- "Denoise step that iteratively denoise the latents. \n"
- + "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n"
- + " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
- + " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
- + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
- )
-
-
-## 2.7 QwenImage-Edit/auto blocks & presets
-
-
-class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage-edit"
- block_classes = [
- QwenImageEditAutoInputStep,
- QwenImageEditAutoBeforeDenoiseStep,
- QwenImageEditAutoDenoiseStep,
- QwenImageAfterDenoiseStep,
- ]
- block_names = ["input", "before_denoise", "denoise", "after_denoise"]
-
- @property
- def description(self):
- return (
- "Core step that performs the denoising process. \n"
- + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
- + " - `QwenImageEditAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
- + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
- + "This step support edit (img2img) and edit inpainting workflow for QwenImage Edit:\n"
- + " - When `processed_mask_image` is provided, it will be used for edit inpainting task.\n"
- + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
- )
-
-
-EDIT_AUTO_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageEditVLEncoderStep()),
- ("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
- ("denoise", QwenImageEditCoreDenoiseStep()),
- ("decode", QwenImageAutoDecodeStep()),
- ]
-)
-
-
-class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
- model_name = "qwenimage-edit"
- block_classes = EDIT_AUTO_BLOCKS.values()
- block_names = EDIT_AUTO_BLOCKS.keys()
-
- @property
- def description(self):
- return (
- "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
- + "- for edit (img2img) generation, you need to provide `image`\n"
- + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
- )
-
-
-#################### QwenImage Edit Plus #####################
-
-# 3. QwenImage-Edit Plus
-
-## 3.1 QwenImage-Edit Plus / edit
-
-#### QwenImage-Edit Plus vl encoder: take both image and text prompts
-QwenImageEditPlusVLEncoderBlocks = InsertableDict(
- [
- ("resize", QwenImageEditPlusResizeDynamicStep()),
- ("encode", QwenImageEditPlusTextEncoderStep()),
- ]
-)
-
-
-class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
- model_name = "qwenimage"
- block_classes = QwenImageEditPlusVLEncoderBlocks.values()
- block_names = QwenImageEditPlusVLEncoderBlocks.keys()
-
- @property
- def description(self) -> str:
- return "QwenImage-Edit Plus VL encoder step that encode the image an text prompts together."
-
-
-#### QwenImage-Edit Plus vae encoder
-QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
- [
- ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
- ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
- ("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents
- ]
-)
-
-
-class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
- model_name = "qwenimage-edit-plus"
- block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
- block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
-
- @property
- def description(self) -> str:
- return "Vae encoder step that encode the image inputs into their latent representations."
-
-
-#### QwenImage Edit Plus input blocks
-QwenImageEditPlusInputBlocks = InsertableDict(
- [
- ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
- (
- "additional_inputs",
- QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]),
- ),
- ]
-)
-
-
-class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
- model_name = "qwenimage-edit-plus"
- block_classes = QwenImageEditPlusInputBlocks.values()
- block_names = QwenImageEditPlusInputBlocks.keys()
-
-
-#### QwenImage Edit Plus presets
-EDIT_PLUS_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageEditPlusVLEncoderStep()),
- ("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
- ("input", QwenImageEditPlusInputStep()),
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsStep()),
- ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
- ("denoise", QwenImageEditDenoiseStep()),
- ("after_denoise", QwenImageAfterDenoiseStep()),
- ("decode", QwenImageDecodeStep()),
- ]
-)
-
-
-QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
- [
- ("prepare_latents", QwenImagePrepareLatentsStep()),
- ("set_timesteps", QwenImageSetTimestepsStep()),
- ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
- ]
-)
-
-
-class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage-edit-plus"
- block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
- block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()
-
- @property
- def description(self):
- return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
-
-
-# auto before_denoise step for edit tasks
-class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
- model_name = "qwenimage-edit-plus"
- block_classes = [QwenImageEditPlusBeforeDenoiseStep]
- block_names = ["edit"]
- block_trigger_inputs = ["image_latents"]
-
- @property
- def description(self):
- return (
- "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
- + "This is an auto pipeline block that works for edit (img2img) task.\n"
- + " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
- + " - if `image_latents` is not provided, step will be skipped."
- )
-
-
-## 3.2 QwenImage-Edit Plus/auto encoders
-
-
-class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
- block_classes = [QwenImageEditPlusVaeEncoderStep]
- block_names = ["edit"]
- block_trigger_inputs = ["image"]
-
- @property
- def description(self):
- return (
- "Vae encoder step that encode the image inputs into their latent representations. \n"
- " This is an auto pipeline block that works for edit task.\n"
- + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n"
- + " - if `image` is not provided, step will be skipped."
- )
-
-
-## 3.3 QwenImage-Edit/auto blocks & presets
-
-
-class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks):
- block_classes = [QwenImageEditPlusInputStep]
- block_names = ["edit"]
- block_trigger_inputs = ["image_latents"]
-
- @property
- def description(self):
- return (
- "Input step that prepares the inputs for the edit denoising step.\n"
- + " It is an auto pipeline block that works for edit task.\n"
- + " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n"
- + " - if `image_latents` is not provided, step will be skipped."
- )
-
-
-class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
- model_name = "qwenimage-edit-plus"
- block_classes = [
- QwenImageEditPlusAutoInputStep,
- QwenImageEditPlusAutoBeforeDenoiseStep,
- QwenImageEditAutoDenoiseStep,
- QwenImageAfterDenoiseStep,
- ]
- block_names = ["input", "before_denoise", "denoise", "after_denoise"]
-
- @property
- def description(self):
- return (
- "Core step that performs the denoising process. \n"
- + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
- + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
- + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
- + "This step support edit (img2img) workflow for QwenImage Edit Plus:\n"
- + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
- )
-
-
-EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
- [
- ("text_encoder", QwenImageEditPlusVLEncoderStep()),
- ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()),
- ("denoise", QwenImageEditPlusCoreDenoiseStep()),
- ("decode", QwenImageAutoDecodeStep()),
- ]
-)
-
-
-class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
- model_name = "qwenimage-edit-plus"
- block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
- block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
-
- @property
- def description(self):
- return (
- "Auto Modular pipeline for edit (img2img) and edit tasks using QwenImage-Edit Plus.\n"
- + "- for edit (img2img) generation, you need to provide `image`\n"
- )
-
-
-# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus
-
-
-ALL_BLOCKS = {
- "text2image": TEXT2IMAGE_BLOCKS,
- "img2img": IMAGE2IMAGE_BLOCKS,
- "edit": EDIT_BLOCKS,
- "edit_inpaint": EDIT_INPAINT_BLOCKS,
- "edit_plus": EDIT_PLUS_BLOCKS,
- "inpaint": INPAINT_BLOCKS,
- "controlnet": CONTROLNET_BLOCKS,
- "auto": AUTO_BLOCKS,
- "edit_auto": EDIT_AUTO_BLOCKS,
- "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS,
-}
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py
new file mode 100644
index 000000000000..ebe0bbbd75ba
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py
@@ -0,0 +1,488 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import PIL.Image
+import torch
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict, OutputParam
+from .before_denoise import (
+ QwenImageControlNetBeforeDenoiserStep,
+ QwenImageCreateMaskLatentsStep,
+ QwenImagePrepareLatentsStep,
+ QwenImagePrepareLatentsWithStrengthStep,
+ QwenImageRoPEInputsStep,
+ QwenImageSetTimestepsStep,
+ QwenImageSetTimestepsWithStrengthStep,
+)
+from .decoders import (
+ QwenImageAfterDenoiseStep,
+ QwenImageDecoderStep,
+ QwenImageInpaintProcessImagesOutputStep,
+ QwenImageProcessImagesOutputStep,
+)
+from .denoise import (
+ QwenImageControlNetDenoiseStep,
+ QwenImageDenoiseStep,
+ QwenImageInpaintControlNetDenoiseStep,
+ QwenImageInpaintDenoiseStep,
+)
+from .encoders import (
+ QwenImageControlNetVaeEncoderStep,
+ QwenImageInpaintProcessImagesInputStep,
+ QwenImageProcessImagesInputStep,
+ QwenImageTextEncoderStep,
+ QwenImageVaeEncoderStep,
+)
+from .inputs import (
+ QwenImageAdditionalInputsStep,
+ QwenImageControlNetInputsStep,
+ QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. VAE ENCODER
+# ====================
+
+
+class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
+ block_names = ["preprocess", "encode"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for inpainting tasks. It:\n"
+ " - Resizes the image to the target size, based on `height` and `width`.\n"
+ " - Processes and updates `image` and `mask_image`.\n"
+ " - Creates `image_latents`."
+ )
+
+
+class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
+ block_names = ["preprocess", "encode"]
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
+
+
+# Auto VAE encoder
+class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
+ block_names = ["inpaint", "img2img"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+ + " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+# optional controlnet vae encoder
+class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetVaeEncoderStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+ + " - if `control_image` is not provided, step will be skipped."
+ )
+
+
+# ====================
+# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
+ block_names = ["text_inputs", "additional_inputs"]
+
+ @property
+ def description(self):
+ return "Input step that prepares the inputs for the img2img denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+class QwenImageInpaintInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageTextInputsStep(),
+ QwenImageAdditionalInputsStep(
+ image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+ ),
+ ]
+ block_names = ["text_inputs", "additional_inputs"]
+
+ @property
+ def description(self):
+ return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+# assemble prepare latents steps
+class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
+ block_names = ["add_noise_to_latents", "create_mask_latents"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
+ " - Add noise to the image latents to create the latents input for the denoiser.\n"
+ " - Create the pachified latents `mask` based on the processedmask image.\n"
+ )
+
+
+# assemble denoising steps
+
+
+# Qwen Image (text2image)
+class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageTextInputsStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsStep(),
+ QwenImageRoPEInputsStep(),
+ QwenImageDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_rope_inputs",
+ "denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
+
+
+# Qwen Image (inpainting)
+class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageInpaintInputStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsWithStrengthStep(),
+ QwenImageInpaintPrepareLatentsStep(),
+ QwenImageRoPEInputsStep(),
+ QwenImageInpaintDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_inpaint_latents",
+ "prepare_rope_inputs",
+ "denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
+
+
+# Qwen Image (image2image)
+class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageImg2ImgInputStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsWithStrengthStep(),
+ QwenImagePrepareLatentsWithStrengthStep(),
+ QwenImageRoPEInputsStep(),
+ QwenImageDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_img2img_latents",
+ "prepare_rope_inputs",
+ "denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
+
+
+# Qwen Image (text2image) with controlnet
+class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageTextInputsStep(),
+ QwenImageControlNetInputsStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsStep(),
+ QwenImageRoPEInputsStep(),
+ QwenImageControlNetBeforeDenoiserStep(),
+ QwenImageControlNetDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "controlnet_input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_rope_inputs",
+ "controlnet_before_denoise",
+ "controlnet_denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
+
+
+# Qwen Image (inpainting) with controlnet
+class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageInpaintInputStep(),
+ QwenImageControlNetInputsStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsWithStrengthStep(),
+ QwenImageInpaintPrepareLatentsStep(),
+ QwenImageRoPEInputsStep(),
+ QwenImageControlNetBeforeDenoiserStep(),
+ QwenImageInpaintControlNetDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "controlnet_input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_inpaint_latents",
+ "prepare_rope_inputs",
+ "controlnet_before_denoise",
+ "controlnet_denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
+
+
+# Qwen Image (image2image) with controlnet
+class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageImg2ImgInputStep(),
+ QwenImageControlNetInputsStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsWithStrengthStep(),
+ QwenImagePrepareLatentsWithStrengthStep(),
+ QwenImageRoPEInputsStep(),
+ QwenImageControlNetBeforeDenoiserStep(),
+ QwenImageControlNetDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "controlnet_input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_img2img_latents",
+ "prepare_rope_inputs",
+ "controlnet_before_denoise",
+ "controlnet_denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
+
+
+# Auto denoise step for QwenImage
+class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
+ block_classes = [
+ QwenImageCoreDenoiseStep,
+ QwenImageInpaintCoreDenoiseStep,
+ QwenImageImg2ImgCoreDenoiseStep,
+ QwenImageControlNetCoreDenoiseStep,
+ QwenImageControlNetInpaintCoreDenoiseStep,
+ QwenImageControlNetImg2ImgCoreDenoiseStep,
+ ]
+ block_names = [
+ "text2image",
+ "inpaint",
+ "img2img",
+ "controlnet_text2image",
+ "controlnet_inpaint",
+ "controlnet_img2img",
+ ]
+ block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"]
+ default_block_name = "text2image"
+
+ def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None):
+ if control_image_latents is not None:
+ if processed_mask_image is not None:
+ return "controlnet_inpaint"
+ elif image_latents is not None:
+ return "controlnet_img2img"
+ else:
+ return "controlnet_text2image"
+ else:
+ if processed_mask_image is not None:
+ return "inpaint"
+ elif image_latents is not None:
+ return "img2img"
+ else:
+ return "text2image"
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
+ + " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n"
+ + " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n"
+ + " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n"
+ + " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n"
+ + " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n"
+ + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
+ + " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings"
+ )
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(
+ name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+ ),
+ ]
+
+
+# ====================
+# 3. DECODE
+# ====================
+
+
+# standard decode step works for most tasks except for inpaint
+class QwenImageDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
+ block_names = ["decode", "postprocess"]
+
+ @property
+ def description(self):
+ return "Decode step that decodes the latents to images and postprocess the generated image."
+
+
+# Inpaint decode step
+class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
+ block_names = ["decode", "postprocess"]
+
+ @property
+ def description(self):
+ return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
+
+
+# Auto decode step for QwenImage
+class QwenImageAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
+ block_names = ["inpaint_decode", "decode"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+ "Decode step that decode the latents into images. \n"
+ " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+ + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
+ )
+
+
+# ====================
+# 4. AUTO BLOCKS & PRESETS
+# ====================
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageAutoVaeEncoderStep()),
+ ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
+ ("denoise", QwenImageAutoCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = AUTO_BLOCKS.values()
+ block_names = AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ + "- for image-to-image generation, you need to provide `image`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`"
+ )
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
+ ]
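+
+# Usage sketch (illustrative comment only, not executed here): how this preset is
+# typically turned into a runnable pipeline. The helper names (`init_pipeline`,
+# `load_components`) and the repo id are assumptions about the modular pipelines
+# API and may differ in practice.
+#
+#   blocks = QwenImageAutoBlocks()
+#   pipe = blocks.init_pipeline("Qwen/Qwen-Image")
+#   pipe.load_components(torch_dtype=torch.bfloat16)
+#   image = pipe(prompt="a watercolor fox", num_inference_steps=30, output="images")[0]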
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py
new file mode 100644
index 000000000000..2683e64080bf
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py
@@ -0,0 +1,353 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import PIL.Image
+import torch
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict, OutputParam
+from .before_denoise import (
+ QwenImageCreateMaskLatentsStep,
+ QwenImageEditRoPEInputsStep,
+ QwenImagePrepareLatentsStep,
+ QwenImagePrepareLatentsWithStrengthStep,
+ QwenImageSetTimestepsStep,
+ QwenImageSetTimestepsWithStrengthStep,
+)
+from .decoders import (
+ QwenImageAfterDenoiseStep,
+ QwenImageDecoderStep,
+ QwenImageInpaintProcessImagesOutputStep,
+ QwenImageProcessImagesOutputStep,
+)
+from .denoise import (
+ QwenImageEditDenoiseStep,
+ QwenImageEditInpaintDenoiseStep,
+)
+from .encoders import (
+ QwenImageEditInpaintProcessImagesInputStep,
+ QwenImageEditProcessImagesInputStep,
+ QwenImageEditResizeStep,
+ QwenImageEditTextEncoderStep,
+ QwenImageVaeEncoderStep,
+)
+from .inputs import (
+ QwenImageAdditionalInputsStep,
+ QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. TEXT ENCODER
+# ====================
+
+
+class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
+ """VL encoder that takes both image and text prompts."""
+
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditResizeStep(),
+ QwenImageEditTextEncoderStep(),
+ ]
+ block_names = ["resize", "encode"]
+
+ @property
+ def description(self) -> str:
+ return "QwenImage-Edit VL encoder step that encode the image and text prompts together."
+
+
+# ====================
+# 2. VAE ENCODER
+# ====================
+
+
+# Edit VAE encoder
+class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditResizeStep(),
+ QwenImageEditProcessImagesInputStep(),
+ QwenImageVaeEncoderStep(),
+ ]
+ block_names = ["resize", "preprocess", "encode"]
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that encode the image inputs into their latent representations."
+
+
+# Edit Inpaint VAE encoder
+class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditResizeStep(),
+ QwenImageEditInpaintProcessImagesInputStep(),
+ QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
+ ]
+ block_names = ["resize", "preprocess", "encode"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
+ " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
+ " - process the resized image and mask image.\n"
+ " - create image latents."
+ )
+
+
+# Auto VAE encoder
+class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
+ " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+# ====================
+# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageEditInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageTextInputsStep(),
+ QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
+ ]
+ block_names = ["text_inputs", "additional_inputs"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the edit denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+ )
+
+
+class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageTextInputsStep(),
+ QwenImageAdditionalInputsStep(
+ image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+ ),
+ ]
+ block_names = ["text_inputs", "additional_inputs"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the edit inpaint denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+ )
+
+
+# assemble prepare latents steps
+class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
+ block_names = ["add_noise_to_latents", "create_mask_latents"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
+ " - Add noise to the image latents to create the latents input for the denoiser.\n"
+ " - Create the patchified latents `mask` based on the processed mask image.\n"
+ )
+
+
+# Qwen Image Edit (image2image) core denoise step
+class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditInputStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsStep(),
+ QwenImageEditRoPEInputsStep(),
+ QwenImageEditDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_rope_inputs",
+ "denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
+
+
+# Qwen Image Edit (inpainting) core denoise step
+class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditInpaintInputStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsWithStrengthStep(),
+ QwenImageEditInpaintPrepareLatentsStep(),
+ QwenImageEditRoPEInputsStep(),
+ QwenImageEditInpaintDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_inpaint_latents",
+ "prepare_rope_inputs",
+ "denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Core denoising workflow for QwenImage-Edit edit inpaint task."
+
+
+# Auto core denoise step for QwenImage Edit
+class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditInpaintCoreDenoiseStep,
+ QwenImageEditCoreDenoiseStep,
+ ]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+ default_block_name = "edit"
+
+ def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]:
+ if processed_mask_image is not None:
+ return "edit_inpaint"
+ elif image_latents is not None:
+ return "edit"
+ return None
+
+ @property
+ def description(self):
+ return (
+ "Auto core denoising step that selects the appropriate workflow based on inputs.\n"
+ " - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n"
+ " - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n"
+ "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
+ )
+
+
+# ====================
+# 4. DECODE
+# ====================
+
+
+# Decode step (standard)
+class QwenImageEditDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
+ block_names = ["decode", "postprocess"]
+
+ @property
+ def description(self):
+ return "Decode step that decodes the latents to images and postprocess the generated image."
+
+
+# Inpaint decode step
+class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
+ block_names = ["decode", "postprocess"]
+
+ @property
+ def description(self):
+ return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image."
+
+
+# Auto decode step
+class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
+ block_names = ["inpaint_decode", "decode"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+ "Decode step that decode the latents into images.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ " - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
+ )
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(
+ name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+ ),
+ ]
+
+
+# ====================
+# 5. AUTO BLOCKS & PRESETS
+# ====================
+
+EDIT_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
+ ("denoise", QwenImageEditAutoCoreDenoiseStep()),
+ ("decode", QwenImageEditAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = EDIT_AUTO_BLOCKS.values()
+ block_names = EDIT_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
+ "- for edit (img2img) generation, you need to provide `image`\n"
+ "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
+ )
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
+ ]
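+
+# Illustrative note (comment only): which branch of the auto blocks runs is
+# decided from the call-time inputs, not at construction time. Roughly:
+#
+#   blocks = QwenImageEditAutoBlocks()
+#   # image + prompt                -> edit (img2img) path
+#   # image + mask_image + prompt   -> edit inpaint path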
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py
new file mode 100644
index 000000000000..99c5b109bf38
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py
@@ -0,0 +1,200 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import PIL.Image
+import torch
+
+from ...utils import logging
+from ..modular_pipeline import SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict, OutputParam
+from .before_denoise import (
+ QwenImageEditPlusRoPEInputsStep,
+ QwenImagePrepareLatentsStep,
+ QwenImageSetTimestepsStep,
+)
+from .decoders import (
+ QwenImageAfterDenoiseStep,
+ QwenImageDecoderStep,
+ QwenImageProcessImagesOutputStep,
+)
+from .denoise import (
+ QwenImageEditDenoiseStep,
+)
+from .encoders import (
+ QwenImageEditPlusProcessImagesInputStep,
+ QwenImageEditPlusResizeStep,
+ QwenImageEditPlusTextEncoderStep,
+ QwenImageVaeEncoderStep,
+)
+from .inputs import (
+ QwenImageEditPlusAdditionalInputsStep,
+ QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. TEXT ENCODER
+# ====================
+
+
+class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
+ """VL encoder that takes both image and text prompts. Uses 384x384 target area."""
+
+ model_name = "qwenimage-edit-plus"
+ block_classes = [
+ QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
+ QwenImageEditPlusTextEncoderStep(),
+ ]
+ block_names = ["resize", "encode"]
+
+ @property
+ def description(self) -> str:
+ return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
+
+
+# ====================
+# 2. VAE ENCODER
+# ====================
+
+
+class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
+ """VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
+
+ model_name = "qwenimage-edit-plus"
+ block_classes = [
+ QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
+ QwenImageEditPlusProcessImagesInputStep(),
+ QwenImageVaeEncoderStep(),
+ ]
+ block_names = ["resize", "preprocess", "encode"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "VAE encoder step that encodes image inputs into latent representations.\n"
+ "Each image is resized independently based on its own aspect ratio to 1024x1024 target area."
+ )
+
+
+# ====================
+# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [
+ QwenImageTextInputsStep(),
+ QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
+ ]
+ block_names = ["text_inputs", "additional_inputs"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the Edit Plus denoising step. It:\n"
+ " - Standardizes text embeddings batch size.\n"
+ " - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n"
+ " - Outputs lists of image_height/image_width for RoPE calculation.\n"
+ " - Defaults height/width from last image in the list."
+ )
+
+
+# Qwen Image Edit Plus (image2image) core denoise step
+class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [
+ QwenImageEditPlusInputStep(),
+ QwenImagePrepareLatentsStep(),
+ QwenImageSetTimestepsStep(),
+ QwenImageEditPlusRoPEInputsStep(),
+ QwenImageEditDenoiseStep(),
+ QwenImageAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_rope_inputs",
+ "denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(
+ name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+ ),
+ ]
+
+
+# ====================
+# 4. DECODE
+# ====================
+
+
+class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
+ block_names = ["decode", "postprocess"]
+
+ @property
+ def description(self):
+ return "Decode step that decodes the latents to images and postprocesses the generated image."
+
+
+# ====================
+# 5. AUTO BLOCKS & PRESETS
+# ====================
+
+EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+ ("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
+ ("denoise", QwenImageEditPlusCoreDenoiseStep()),
+ ("decode", QwenImageEditPlusDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
+ block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
+ "- `image` is required input (can be single image or list of images).\n"
+ "- Each image is resized independently based on its own aspect ratio.\n"
+ "- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
+ )
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
+ ]
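+
+# Illustrative note (comment only): Edit Plus is built around multi-reference
+# editing, so `image` may be a single image or a list; each reference is resized
+# on its own aspect ratio before VL (384x384 target area) and VAE (1024x1024
+# target area) encoding. A hypothetical call might look like:
+#
+#   blocks = QwenImageEditPlusAutoBlocks()
+#   # pipe(image=[img_a, img_b], prompt="put the cat from Picture 1 on the sofa in Picture 2")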
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py
new file mode 100644
index 000000000000..63ee36df5112
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py
@@ -0,0 +1,178 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List
+
+import PIL.Image
+import torch
+
+from ...utils import logging
+from ..modular_pipeline import SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict, OutputParam
+from .before_denoise import (
+ QwenImageLayeredPrepareLatentsStep,
+ QwenImageLayeredRoPEInputsStep,
+ QwenImageLayeredSetTimestepsStep,
+)
+from .decoders import (
+ QwenImageLayeredAfterDenoiseStep,
+ QwenImageLayeredDecoderStep,
+)
+from .denoise import (
+ QwenImageLayeredDenoiseStep,
+)
+from .encoders import (
+ QwenImageEditProcessImagesInputStep,
+ QwenImageLayeredGetImagePromptStep,
+ QwenImageLayeredPermuteLatentsStep,
+ QwenImageLayeredResizeStep,
+ QwenImageTextEncoderStep,
+ QwenImageVaeEncoderStep,
+)
+from .inputs import (
+ QwenImageLayeredAdditionalInputsStep,
+ QwenImageTextInputsStep,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# ====================
+# 1. TEXT ENCODER
+# ====================
+
+
+class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
+ """Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
+
+ model_name = "qwenimage-layered"
+ block_classes = [
+ QwenImageLayeredResizeStep(),
+ QwenImageLayeredGetImagePromptStep(),
+ QwenImageTextEncoderStep(),
+ ]
+ block_names = ["resize", "get_image_prompt", "encode"]
+
+ @property
+ def description(self) -> str:
+ return "QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided."
+
+
+# ====================
+# 2. VAE ENCODER
+# ====================
+
+
+# Edit VAE encoder
+class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-layered"
+ block_classes = [
+ QwenImageLayeredResizeStep(),
+ QwenImageEditProcessImagesInputStep(),
+ QwenImageVaeEncoderStep(),
+ QwenImageLayeredPermuteLatentsStep(),
+ ]
+ block_names = ["resize", "preprocess", "encode", "permute"]
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that encode the image inputs into their latent representations."
+
+
+# ====================
+# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
+# ====================
+
+
+# assemble input steps
+class QwenImageLayeredInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-layered"
+ block_classes = [
+ QwenImageTextInputsStep(),
+ QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
+ ]
+ block_names = ["text_inputs", "additional_inputs"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the layered denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+ )
+
+
+# Qwen Image Layered (image2image) core denoise step
+class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-layered"
+ block_classes = [
+ QwenImageLayeredInputStep(),
+ QwenImageLayeredPrepareLatentsStep(),
+ QwenImageLayeredSetTimestepsStep(),
+ QwenImageLayeredRoPEInputsStep(),
+ QwenImageLayeredDenoiseStep(),
+ QwenImageLayeredAfterDenoiseStep(),
+ ]
+ block_names = [
+ "input",
+ "prepare_latents",
+ "set_timesteps",
+ "prepare_rope_inputs",
+ "denoise",
+ "after_denoise",
+ ]
+
+ @property
+ def description(self):
+ return "Core denoising workflow for QwenImage-Layered img2img task."
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(
+ name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+ ),
+ ]
+
+
+# ====================
+# 4. AUTO BLOCKS & PRESETS
+# ====================
+
+LAYERED_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageLayeredTextEncoderStep()),
+ ("vae_encoder", QwenImageLayeredVaeEncoderStep()),
+ ("denoise", QwenImageLayeredCoreDenoiseStep()),
+ ("decode", QwenImageLayeredDecoderStep()),
+ ]
+)
+
+
+class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-layered"
+ block_classes = LAYERED_AUTO_BLOCKS.values()
+ block_names = LAYERED_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered."
+
+ @property
+ def outputs(self):
+ return [
+ OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
+ ]
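+
+# Illustrative note (comment only): the layered preset decomposes an input image
+# into layers, and `prompt` is optional because the text encoder step can caption
+# the image itself (see QwenImageLayeredGetImagePromptStep). For example:
+#
+#   blocks = QwenImageLayeredAutoBlocks()
+#   # pipe(image=img)                          # auto-captioned prompt
+#   # pipe(image=img, prompt="a movie poster") # user-provided prompt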
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
index 59e1a13a5db2..892435989d00 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
@@ -90,6 +90,88 @@ def unpack_latents(self, latents, height, width, vae_scale_factor=8):
return latents
+class QwenImageLayeredPachifier(ConfigMixin):
+ """
+ A class to pack and unpack latents for QwenImage Layered.
+
+ Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W).
+ """
+
+ config_name = "config.json"
+
+ @register_to_config
+ def __init__(self, patch_size: int = 2):
+ super().__init__()
+
+ def pack_latents(self, latents):
+ """
+ Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4).
+ """
+
+ if latents.ndim != 5:
+ raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}")
+
+ batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape
+ patch_size = self.config.patch_size
+
+ if latent_height % patch_size != 0 or latent_width % patch_size != 0:
+ raise ValueError(
+ f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
+ )
+
+ latents = latents.view(
+ batch_size,
+ layers,
+ num_channels_latents,
+ latent_height // patch_size,
+ patch_size,
+ latent_width // patch_size,
+ patch_size,
+ )
+ latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
+ latents = latents.reshape(
+ batch_size,
+ layers * (latent_height // patch_size) * (latent_width // patch_size),
+ num_channels_latents * patch_size * patch_size,
+ )
+ return latents
+
+ def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8):
+ """
+ Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W).
+ """
+
+ if latents.ndim != 3:
+ raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
+
+ batch_size, _, channels = latents.shape
+ patch_size = self.config.patch_size
+
+ height = patch_size * (int(height) // (vae_scale_factor * patch_size))
+ width = patch_size * (int(width) // (vae_scale_factor * patch_size))
+
+ latents = latents.view(
+ batch_size,
+ layers + 1,
+ height // patch_size,
+ width // patch_size,
+ channels // (patch_size * patch_size),
+ patch_size,
+ patch_size,
+ )
+ latents = latents.permute(0, 1, 4, 2, 5, 3, 6)
+ latents = latents.reshape(
+ batch_size,
+ layers + 1,
+ channels // (patch_size * patch_size),
+ height,
+ width,
+ )
+ latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w)
+
+ return latents
+
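+# Shape sketch (comment only, example values are illustrative): with the default
+# patch_size=2, pack_latents maps (1, 4, 16, 64, 64) -> (1, 4 * 32 * 32, 16 * 4)
+# = (1, 4096, 64); unpack_latents(latents, height=512, width=512, layers=3,
+# vae_scale_factor=8) restores (1, 16, 4, 64, 64), i.e. (B, C, layers + 1, H, W).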
+
class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage.
@@ -203,3 +285,13 @@ class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline):
"""
default_blocks_name = "QwenImageEditPlusAutoBlocks"
+
+
+class QwenImageLayeredModularPipeline(QwenImageModularPipeline):
+ """
+ A ModularPipeline for QwenImage-Layered.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageLayeredAutoBlocks"
diff --git a/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py b/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py
new file mode 100644
index 000000000000..8e7beb555760
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py
@@ -0,0 +1,121 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Prompt templates for QwenImage pipelines.
+
+This module centralizes all prompt templates used across different QwenImage pipeline variants:
+- QwenImage (base): Text-only encoding for text-to-image generation
+- QwenImage Edit: VL encoding with single image for image editing
+- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing
+- QwenImage Layered: Auto-captioning for image decomposition
+"""
+
+# ============================================
+# QwenImage Base (text-only encoding)
+# ============================================
+# Used for text-to-image generation where only text prompt is encoded
+
+QWENIMAGE_PROMPT_TEMPLATE = (
+ "<|im_start|>system\n"
+ "Describe the image by detailing the color, shape, size, texture, quantity, text, "
+ "spatial relationships of the objects and background:<|im_end|>\n"
+ "<|im_start|>user\n{}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+)
+QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34
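+
+# Illustrative sketch (comment only): the template is filled with the raw user
+# prompt via str.format, and the first START_IDX tokens of the encoder output
+# (the template preamble) are typically dropped before the embeddings are used.
+#
+#   text = QWENIMAGE_PROMPT_TEMPLATE.format("A red bicycle leaning against a brick wall")
+#   # hidden_states[:, QWENIMAGE_PROMPT_TEMPLATE_START_IDX:] keeps only the prompt tokens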
+
+
+# ============================================
+# QwenImage Edit (VL encoding with single image)
+# ============================================
+# Used for single-image editing where both image and text are encoded together
+
+QWENIMAGE_EDIT_PROMPT_TEMPLATE = (
+ "<|im_start|>system\n"
+ "Describe the key features of the input image (color, shape, size, texture, objects, background), "
+ "then explain how the user's text instruction should alter or modify the image. "
+ "Generate a new image that meets the user's requirements while maintaining consistency "
+ "with the original input where appropriate.<|im_end|>\n"
+ "<|im_start|>user\n"
+ "<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+)
+QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64
+
+
+# ============================================
+# QwenImage Edit Plus (VL encoding with multiple images)
+# ============================================
+# Used for multi-reference editing where multiple images and text are encoded together
+# The img_template is used to format each image in the prompt
+
+QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = (
+ "<|im_start|>system\n"
+ "Describe the key features of the input image (color, shape, size, texture, objects, background), "
+ "then explain how the user's text instruction should alter or modify the image. "
+ "Generate a new image that meets the user's requirements while maintaining consistency "
+ "with the original input where appropriate.<|im_end|>\n"
+ "<|im_start|>user\n{}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+)
+QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64
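+
+# Illustrative sketch (comment only): for multi-reference editing the per-image
+# template is filled once per input image and the result is substituted into the
+# main template, e.g.
+#
+#   imgs = "".join(QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE.format(i + 1) for i in range(2))
+#   text = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE.format(imgs + "replace the sky with a sunset")
+#
+# which yields "Picture 1: ... Picture 2: ..." followed by the instruction.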
+
+
+# ============================================
+# QwenImage Layered (auto-captioning)
+# ============================================
+# Used for image decomposition where the VL model generates a caption from the input image
+# if no prompt is provided. These prompts instruct the model to describe the image in detail.
+
+QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = (
+ "<|im_start|>system\n"
+ "You are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>user\n"
+ "# Image Annotator\n"
+ "You are a professional image annotator. Please write an image caption based on the input image:\n"
+ "1. Write the caption using natural, descriptive language without structured formats or rich text.\n"
+ "2. Enrich caption details by including:\n"
+ " - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n"
+ " - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, "
+ "attachment relations, action relations, comparative relations, causal relations, and so on\n"
+ " - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n"
+ " - Identify the text clearly visible in the image, without translation or explanation, "
+ "and highlight it in the caption with quotation marks\n"
+ "3. Maintain authenticity and accuracy:\n"
+ " - Avoid generalizations\n"
+ " - Describe all visible information in the image, while do not add information not explicitly shown in the image\n"
+ "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
+ "<|im_start|>assistant\n"
+)
+
+QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = (
+ "<|im_start|>system\n"
+ "You are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>user\n"
+ "# 图像标注器\n"
+ "你是一个专业的图像标注器。请基于输入图像,撰写图注:\n"
+ "1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n"
+ "2. 通过加入以下内容,丰富图注细节:\n"
+ " - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n"
+ " - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n"
+ " - 环境细节:例如天气、光照、颜色、纹理、气氛等\n"
+ " - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n"
+ "3. 保持真实性与准确性:\n"
+ " - 不要使用笼统的描述\n"
+ " - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n"
+ "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
+ "<|im_start|>assistant\n"
+)
diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py
index b3b70b2f9be1..905111bcf42d 100644
--- a/src/diffusers/modular_pipelines/wan/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py
@@ -84,7 +84,7 @@ def description(self):
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
- block_names = ["image_resize", "vae_image_encoder"]
+ block_names = ["image_resize", "vae_encoder"]
@property
def description(self):
@@ -142,7 +142,7 @@ def description(self):
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
- block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
+ block_names = ["image_resize", "last_image_resize", "vae_encoder"]
@property
def description(self):
@@ -203,7 +203,7 @@ def description(self):
## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
- block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
+ block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
@@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks):
block_names = [
"text_encoder",
"image_encoder",
- "vae_image_encoder",
+ "vae_encoder",
"denoise",
"decode",
]
@@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks):
]
block_names = [
"text_encoder",
- "vae_image_encoder",
+ "vae_encoder",
"denoise",
"decode",
]
@@ -384,7 +384,7 @@ def description(self):
[
("image_resize", WanImageResizeStep),
("image_encoder", WanImage2VideoImageEncoderStep),
- ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
+ ("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
@@ -401,7 +401,7 @@ def description(self):
("image_resize", WanImageResizeStep),
("last_image_resize", WanImageCropResizeStep),
("image_encoder", WanFLF2VImageEncoderStep),
- ("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
+ ("vae_encoder", WanFLF2VVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
@@ -416,7 +416,7 @@ def description(self):
[
("text_encoder", WanTextEncoderStep),
("image_encoder", WanAutoImageEncoderStep),
- ("vae_image_encoder", WanAutoVaeImageEncoderStep),
+ ("vae_encoder", WanAutoVaeImageEncoderStep),
("denoise", WanAutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
@@ -438,7 +438,7 @@ def description(self):
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("image_resize", WanImageResizeStep),
- ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
+ ("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
@@ -450,7 +450,7 @@ def description(self):
AUTO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
- ("vae_image_encoder", WanAutoVaeImageEncoderStep),
+ ("vae_encoder", WanAutoVaeImageEncoderStep),
("denoise", Wan22AutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index f7615c1a4439..65378631a172 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -15,6 +15,7 @@
is_torch_available,
is_torch_npu_available,
is_transformers_available,
+ is_transformers_version,
)
@@ -128,8 +129,8 @@
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
- _import_structure["bria_fibo"] = ["BriaFiboPipeline"]
- _import_structure["flux2"] = ["Flux2Pipeline"]
+ _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
+ _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -154,7 +155,7 @@
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
- _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
+ _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline", "ChromaInpaintPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
@@ -288,7 +289,9 @@
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
+ "LTXI2VLongMultiPromptPipeline",
]
+ _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -432,6 +435,8 @@
"QwenImageLayeredPipeline",
]
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
+ _import_structure["glm_image"] = ["GlmImagePipeline"]
+
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -592,8 +597,8 @@
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
- from .bria_fibo import BriaFiboPipeline
- from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
+ from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline
+ from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
from .chronoedit import ChronoEditPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
@@ -673,7 +678,8 @@
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
- from .flux2 import Flux2Pipeline
+ from .flux2 import Flux2KleinPipeline, Flux2Pipeline
+ from .glm_image import GlmImagePipeline
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
@@ -729,7 +735,14 @@
LEditsPPPipelineStableDiffusionXL,
)
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
- from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
+ from .ltx import (
+ LTXConditionPipeline,
+ LTXI2VLongMultiPromptPipeline,
+ LTXImageToVideoPipeline,
+ LTXLatentUpsamplePipeline,
+ LTXPipeline,
+ )
+ from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py
index 3be0129088fb..42083378d465 100644
--- a/src/diffusers/pipelines/allegro/pipeline_allegro.py
+++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py
@@ -887,7 +887,13 @@ def __call__(
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index 6f3a609aba4a..51a9a31c4259 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -897,16 +897,20 @@ def __call__(
dtype = self.dtype
# 3. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
if not enforce_inference_steps:
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
else:
denoising_inference_steps = int(num_inference_steps / strength)
timesteps, denoising_inference_steps = retrieve_timesteps(
- self.scheduler, denoising_inference_steps, device, timesteps, sigmas
+ self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas
)
timesteps = timesteps[-num_inference_steps:]
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
index b00f344598ad..c3ac7df2cc8c 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
@@ -1100,16 +1100,20 @@ def __call__(
dtype = self.dtype
# 3. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
if not enforce_inference_steps:
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
else:
denoising_inference_steps = int(num_inference_steps / strength)
timesteps, denoising_inference_steps = retrieve_timesteps(
- self.scheduler, denoising_inference_steps, device, timesteps, sigmas
+ self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas
)
timesteps = timesteps[-num_inference_steps:]
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index bb9884e41381..1d75e4bef31e 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -586,7 +586,13 @@ def __call__(
# 4. Prepare timesteps
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index c14910250b54..5ee44190e23b 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -52,6 +52,8 @@
FluxKontextPipeline,
FluxPipeline,
)
+from .flux2 import Flux2KleinPipeline, Flux2Pipeline
+from .glm_image import GlmImagePipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -99,6 +101,7 @@
QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
+ QwenImageLayeredPipeline,
QwenImagePipeline,
)
from .sana import SanaPipeline
@@ -162,11 +165,14 @@
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("flux-kontext", FluxKontextPipeline),
+ ("flux2-klein", Flux2KleinPipeline),
+ ("flux2", Flux2Pipeline),
("lumina", LuminaPipeline),
("lumina2", Lumina2Pipeline),
("chroma", ChromaPipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
+ ("glm_image", GlmImagePipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),
@@ -199,9 +205,12 @@
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
("flux-kontext", FluxKontextPipeline),
+ ("flux2-klein", Flux2KleinPipeline),
+ ("flux2", Flux2Pipeline),
("qwenimage", QwenImageImg2ImgPipeline),
("qwenimage-edit", QwenImageEditPipeline),
("qwenimage-edit-plus", QwenImageEditPlusPipeline),
+ ("qwenimage-layered", QwenImageLayeredPipeline),
("z-image", ZImageImg2ImgPipeline),
]
)
diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py
index 206a463b394b..8dd77270902c 100644
--- a/src/diffusers/pipelines/bria_fibo/__init__.py
+++ b/src/diffusers/pipelines/bria_fibo/__init__.py
@@ -23,6 +23,8 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
+ _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"]
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +35,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
+ from .pipeline_bria_fibo_edit import BriaFiboEditPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py
new file mode 100644
index 000000000000..aae8fc7367da
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py
@@ -0,0 +1,1133 @@
+# Copyright (c) Bria.ai. All rights reserved.
+#
+# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
+# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
+#
+# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
+# indicate if changes were made, and do not use the material for commercial purposes.
+#
+# See the license for further details.
+
+import json
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin
+from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
+from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
+from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+PipelineMaskInput = Union[
+ torch.FloatTensor, Image.Image, List[Image.Image], List[torch.FloatTensor], np.ndarray, List[np.ndarray]
+]
+
+# TODO: Update example docstring
+EXAMPLE_DOC_STRING = """
+ Example:
+ ```python
+    import json
+
+    import torch
+    from PIL import Image
+
+    from diffusers import BriaFiboEditPipeline
+    from diffusers.modular_pipelines import ModularPipelineBlocks
+
+ torch.set_grad_enabled(False)
+ vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
+ vlm_pipe = vlm_pipe.init_pipeline()
+
+ pipe = BriaFiboEditPipeline.from_pretrained(
+ "briaai/fibo-edit",
+ torch_dtype=torch.bfloat16,
+ )
+ pipe.to("cuda")
+
+ output = vlm_pipe(
+ prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality."
+ )
+ json_prompt_generate = json.loads(output.values["json_prompt"])
+
+ image = Image.open("image_generate.png")
+
+ edit_prompt = "Make the owl to be a cat"
+
+ json_prompt_generate["edit_instruction"] = edit_prompt
+
+ results_generate = pipe(
+ prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np"
+ )
+ ```
+"""
+
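+# Aspect-ratio buckets keyed by an approximate pixel budget; each entry is a (width, height) pair. When `height` and
+# `width` are not passed, `__call__` picks the pair from the 1024x1024 budget whose aspect ratio is closest to the
+# input image.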
+PREFERRED_RESOLUTION = {
+ 256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)],
+ 512 * 512: [
+ (416, 624),
+ (432, 592),
+ (464, 560),
+ (512, 512),
+ (544, 480),
+ (576, 448),
+ (592, 432),
+ (608, 416),
+ (624, 416),
+ (640, 400),
+ (672, 384),
+ (704, 368),
+ ],
+ 1024 * 1024: [
+ (832, 1248),
+ (880, 1184),
+ (912, 1136),
+ (1024, 1024),
+ (1136, 912),
+ (1184, 880),
+ (1216, 848),
+ (1248, 832),
+ (1248, 832),
+ (1264, 816),
+ (1296, 800),
+ (1360, 768),
+ ],
+}
+
+
+def is_valid_edit_json(json_input: Union[str, dict]):
+ """
+ Check if the input is a valid JSON string or dict with an "edit_instruction" key.
+
+ Args:
+ json_input (`str` or `dict`):
+ The JSON string or dict to check.
+
+ Returns:
+ `bool`: True if the input is a valid JSON string or dict with an "edit_instruction" key, False otherwise.
+ """
+ try:
+ if isinstance(json_input, str) and "edit_instruction" in json_input:
+ json.loads(json_input)
+ return True
+ elif isinstance(json_input, dict) and "edit_instruction" in json_input:
+ return True
+ else:
+ return False
+ except json.JSONDecodeError:
+ return False
+
+
+def is_valid_mask(mask: PipelineMaskInput):
+ """
+ Check if the mask is a valid mask.
+ """
+ if isinstance(mask, torch.Tensor):
+ return True
+ elif isinstance(mask, Image.Image):
+ return True
+ elif isinstance(mask, list):
+ return all(isinstance(m, (torch.Tensor, Image.Image, np.ndarray)) for m in mask)
+ elif isinstance(mask, np.ndarray):
+ return mask.ndim in [2, 3] and mask.min() >= 0 and mask.max() <= 1
+ else:
+ return False
+
+
+def get_mask_size(mask: PipelineMaskInput):
+ """
+ Get the size of the mask.
+ """
+ if isinstance(mask, torch.Tensor):
+ return mask.shape[-2:]
+ elif isinstance(mask, Image.Image):
+ return mask.size[::-1] # (height, width)
+ elif isinstance(mask, list):
+ return [get_mask_size(m) for m in mask]
+ elif isinstance(mask, np.ndarray):
+ return mask.shape[-2:]
+ else:
+ return None
+
+
+def get_image_size(image: PipelineImageInput):
+ """
+ Get the size of the image.
+ """
+ if isinstance(image, torch.Tensor):
+ return image.shape[-2:]
+ elif isinstance(image, Image.Image):
+ return image.size[::-1] # (height, width)
+ elif isinstance(image, list):
+ return [get_image_size(i) for i in image]
+ else:
+ return None
+
+
+def paste_mask_on_image(mask: PipelineMaskInput, image: PipelineImageInput):
+ """convert mask and image to PIL Images and paste the mask on the image"""
+ if isinstance(mask, torch.Tensor):
+ if mask.ndim == 3 and mask.shape[0] == 1:
+ mask = mask.squeeze(0)
+ mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8))
+ elif isinstance(mask, Image.Image):
+ pass
+ elif isinstance(mask, list):
+ mask = mask[0]
+ if isinstance(mask, torch.Tensor):
+ if mask.ndim == 3 and mask.shape[0] == 1:
+ mask = mask.squeeze(0)
+ mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8))
+ elif isinstance(mask, np.ndarray):
+ mask = Image.fromarray((mask * 255).astype(np.uint8))
+ elif isinstance(mask, np.ndarray):
+ mask = Image.fromarray((mask * 255).astype(np.uint8))
+
+ if isinstance(image, torch.Tensor):
+ if image.ndim == 3:
+ image = image.permute(1, 2, 0)
+ image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))
+ elif isinstance(image, Image.Image):
+ pass
+ elif isinstance(image, list):
+ image = image[0]
+ if isinstance(image, torch.Tensor):
+ if image.ndim == 3:
+ image = image.permute(1, 2, 0)
+ image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))
+ elif isinstance(image, np.ndarray):
+ image = Image.fromarray((image * 255).astype(np.uint8))
+ elif isinstance(image, np.ndarray):
+ image = Image.fromarray((image * 255).astype(np.uint8))
+
+ mask = mask.convert("L")
+ image = image.convert("RGB")
+ gray_color = (128, 128, 128)
+ gray_img = Image.new("RGB", image.size, gray_color)
+ image = Image.composite(gray_img, image, mask)
+ return image
+
+
+class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+ r"""
+ Args:
+ transformer (`BriaFiboTransformer2DModel`):
+ The transformer model for 2D diffusion modeling.
+ scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
+ Scheduler to be used with `transformer` to denoise the encoded latents.
+ vae (`AutoencoderKLWan`):
+ Variational Auto-Encoder for encoding and decoding images to and from latent representations.
+ text_encoder (`SmolLM3ForCausalLM`):
+ Text encoder for processing input prompts.
+ tokenizer (`AutoTokenizer`):
+ Tokenizer used for processing the input text prompts for the text_encoder.
+ """
+
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: BriaFiboTransformer2DModel,
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
+ vae: AutoencoderKLWan,
+ text_encoder: SmolLM3ForCausalLM,
+ tokenizer: AutoTokenizer,
+ ):
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+        self.vae_scale_factor = 16
+        # Number of latent channels from the Wan VAE; used in `__call__` to detect pre-encoded latent inputs.
+        self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.default_sample_size = 32
+
+ def get_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 2048,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if not prompt:
+ raise ValueError("`prompt` must be a non-empty string or list of strings.")
+
+ batch_size = len(prompt)
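+        # Begin-of-text token id; prompts that are empty strings are encoded as this single token below.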
+ bot_token_id = 128000
+
+ text_encoder_device = device if device is not None else torch.device("cpu")
+ if not isinstance(text_encoder_device, torch.device):
+ text_encoder_device = torch.device(text_encoder_device)
+
+ if all(p == "" for p in prompt):
+ input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
+ attention_mask = torch.ones_like(input_ids)
+ else:
+ tokenized = self.tokenizer(
+ prompt,
+ padding="longest",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ input_ids = tokenized.input_ids.to(text_encoder_device)
+ attention_mask = tokenized.attention_mask.to(text_encoder_device)
+
+ if any(p == "" for p in prompt):
+ empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
+ input_ids[empty_rows] = bot_token_id
+ attention_mask[empty_rows] = 1
+
+ encoder_outputs = self.text_encoder(
+ input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_outputs.hidden_states
+
+ prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ hidden_states = tuple(
+ layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
+ )
+ attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
+
+ return prompt_embeds, hidden_states, attention_mask
+
+ @staticmethod
+ def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
+ # Pad embeddings to `max_tokens` while preserving the mask of real tokens.
+ batch_size, seq_len, dim = prompt_embeds.shape
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
+ else:
+ attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+
+ if max_tokens < seq_len:
+ raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
+
+ if max_tokens > seq_len:
+ pad_length = max_tokens - seq_len
+ padding = torch.zeros(
+ (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
+
+ mask_padding = torch.zeros(
+ (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
+
+ return prompt_embeds, attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 3000,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ guidance_scale (`float`):
+ Guidance scale for classifier free guidance.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ prompt_attention_mask = None
+ negative_prompt_attention_mask = None
+ if prompt_embeds is None:
+ prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+ prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
+
+ if guidance_scale > 1:
+ if isinstance(negative_prompt, list) and negative_prompt[0] is None:
+ negative_prompt = ""
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
+ negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ # Pad to longest
+ if prompt_attention_mask is not None:
+ prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+
+ if negative_prompt_embeds is not None:
+ if negative_prompt_attention_mask is not None:
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(
+ device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
+ )
+ max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
+
+ prompt_embeds, prompt_attention_mask = self.pad_embedding(
+ prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
+ )
+ prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
+ negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
+ else:
+ max_tokens = prompt_embeds.shape[1]
+ prompt_embeds, prompt_attention_mask = self.pad_embedding(
+ prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
+ )
+ negative_prompt_layers = None
+
+ dtype = self.text_encoder.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ prompt_layers,
+ negative_prompt_layers,
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @staticmethod
+ # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels)
+ latents = latents.permute(0, 3, 1, 2)
+
+ return latents
+
+ @staticmethod
+ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.permute(0, 2, 3, 1)
+ latents = latents.reshape(batch_size, height * width, num_channels_latents)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ do_patching=False,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
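+        # With `do_patching` the latents are packed into 2x2 patches (Flux-style tokens); otherwise every latent
+        # pixel becomes its own token and the image ids are built at full latent resolution.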
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ if do_patching:
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ else:
+ latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ @staticmethod
+ def _prepare_attention_mask(attention_mask):
+ attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
+
+ # convert to 0 - keep, -inf ignore
+ attention_matrix = torch.where(
+ attention_matrix == 1, 0.0, -torch.inf
+ ) # Apply -inf to ignored tokens for nulling softmax score
+ return attention_matrix
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[PipelineImageInput] = None,
+ mask: Optional[PipelineMaskInput] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 30,
+ timesteps: List[int] = None,
+ seed: Optional[int] = None,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 3000,
+ do_patching=False,
+ _auto_resize: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*):
+ The image to guide the image generation. If not defined, the pipeline will generate an image from
+ scratch.
+            height (`int`, *optional*):
+                The height in pixels of the generated image. If not provided, it is inferred from `image` using the
+                closest preferred resolution bucket.
+            width (`int`, *optional*):
+                The width in pixels of the generated image. If not provided, it is inferred from `image` using the
+                closest preferred resolution bucket.
+            num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ seed (`int`, *optional*):
+ A seed used to make generation deterministic.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.bria_fibo.BriaFiboPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
+ do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
+ Examples:
+ Returns:
+            [`~pipelines.bria_fibo.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.bria_fibo.BriaFiboPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if height is None or width is None:
+ if image is not None:
+ image_height, image_width = self.image_processor.get_default_height_width(image)
+ if _auto_resize:
+ image_width, image_height = min(
+ PREFERRED_RESOLUTION[1024 * 1024],
+ key=lambda size: abs(size[0] / size[1] - image_width / image_height),
+ )
+ width, height = image_width, image_height
+ else:
+ raise ValueError("You must provide either an image or both height and width.")
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ seed=seed,
+ image=image,
+ mask=mask,
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
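+        # When a mask is supplied, the masked region of the conditioning image is replaced with flat gray
+        # (see `paste_mask_on_image`), which marks the area the edit should target.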
+ if mask is not None and image is not None:
+ image = paste_mask_on_image(mask, image)
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+
+ if prompt is not None and is_valid_edit_json(prompt):
+ prompt = json.dumps(prompt)
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ if generator is None and seed is not None:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ prompt_layers,
+ negative_prompt_layers,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=guidance_scale,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ num_images_per_prompt=num_images_per_prompt,
+ lora_scale=lora_scale,
+ )
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if guidance_scale > 1:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_layers = [
+ torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
+ ]
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
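+        # The transformer consumes one text-encoder hidden state per block (passed below as `text_encoder_layers`),
+        # so trim or repeat `prompt_layers` until its length matches the total number of transformer blocks.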
+ total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
+ self.transformer.single_transformer_blocks
+ )
+ if len(prompt_layers) >= total_num_layers_transformer:
+ # remove first layers
+ prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
+ else:
+ # duplicate last layer
+ prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
+
+ # Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ image = self.image_processor.resize(image, height, width)
+ image = self.image_processor.preprocess(image, height, width)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ if do_patching:
+ num_channels_latents = int(num_channels_latents / 4)
+
+ latents, latent_image_ids = self.prepare_latents(
+ prompt_batch_size,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ do_patching,
+ )
+
+ if image is not None:
+ image_latents, image_ids = self.prepare_image_latents(
+ image=image,
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ )
+ latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension
+ else:
+ image_latents = None
+
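+        # Build the attention mask over the joint sequence [prompt tokens | target latents | reference-image latents];
+        # under classifier-free guidance the per-sample masks are repeated for the doubled batch.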
+ latent_attention_mask = torch.ones(
+ [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
+ )
+ if guidance_scale > 1:
+ latent_attention_mask = latent_attention_mask.repeat(2, 1)
+
+ if image_latents is None:
+ attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
+ else:
+ image_latent_attention_mask = torch.ones(
+ [image_latents.shape[0], image_latents.shape[1]],
+ dtype=image_latents.dtype,
+ device=image_latents.device,
+ )
+ if guidance_scale > 1:
+ image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1)
+ attention_mask = torch.cat(
+ [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1
+ )
+
+ attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq
+ attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
+
+ if self._joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+ self._joint_attention_kwargs["attention_mask"] = attention_mask
+
+ # Adapt scheduler to dynamic shifting (resolution dependent)
+
+ if do_patching:
+ seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
+ else:
+ seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
+ mu = calculate_shift(
+ seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+
+ # Init sigmas and timesteps according to shift size
+ # This changes the scheduler in-place according to the dynamic scheduling
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps=num_inference_steps,
+ device=device,
+ timesteps=None,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+        # Support older diffusers versions where the ids carry a leading batch dimension
+ if len(latent_image_ids.shape) == 3:
+ latent_image_ids = latent_image_ids[0]
+
+ if len(text_ids.shape) == 3:
+ text_ids = text_ids[0]
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = latents
+
+ if image_latents is not None:
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(
+ device=latent_model_input.device, dtype=latent_model_input.dtype
+ )
+
+                # This predicts "v" for flow matching (or eps for diffusion)
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ text_encoder_layers=prompt_layers,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ )[0]
+
+ # perform guidance
+ if guidance_scale > 1:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ if do_patching:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ else:
+ latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
+
+ latents = latents.unsqueeze(dim=2)
+ latents_device = latents[0].device
+ latents_dtype = latents[0].dtype
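+            # Undo the per-channel latent normalization (mean/std from the Wan VAE config) before decoding.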
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents_device, latents_dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents_device, latents_dtype
+ )
+ latents_scaled = [latent / latents_std + latents_mean for latent in latents]
+ latents_scaled = torch.cat(latents_scaled, dim=0)
+ image = []
+ for scaled_latent in latents_scaled:
+ curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
+ curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
+ image.append(curr_image)
+ if len(image) == 1:
+ image = image[0]
+ else:
+ image = np.stack(image, axis=0)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return BriaFiboPipelineOutput(images=image)
+
+ def prepare_image_latents(
+ self,
+ image: torch.Tensor,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ):
+ image = image.to(device=device, dtype=dtype)
+
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ # scaling
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, dtype
+ )
+
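+        # The Wan VAE expects a frame dimension, so the image is encoded as a single-frame video; the frame axis is
+        # dropped again right after normalization.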
+ image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean
+ latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw]
+ image_latents_cthw = torch.concat(latents_scaled, dim=0)
+ image_latents_bchw = image_latents_cthw[:, :, 0, :, :]
+
+ image_latent_height, image_latent_width = image_latents_bchw.shape[2:]
+ image_latents_bsd = self._pack_latents_no_patch(
+ latents=image_latents_bchw,
+ batch_size=batch_size,
+ num_channels_latents=num_channels_latents,
+ height=image_latent_height,
+ width=image_latent_width,
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+ return image_latents_bsd, image_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ seed,
+ image,
+ mask,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if seed is not None and not isinstance(seed, int):
+ raise ValueError("Seed must be an integer")
+ if image is not None and not isinstance(image, (torch.Tensor, Image.Image, list)):
+ raise ValueError("Image must be a valid image")
+ if image is None and mask is not None:
+ raise ValueError("If mask is provided, image must also be provided")
+
+ if mask is not None and not is_valid_mask(mask):
+ raise ValueError("Mask must be a valid mask")
+
+ if mask is not None and image is not None and not (get_mask_size(mask) == get_image_size(image)):
+ raise ValueError("Mask and image must have the same size")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and not is_valid_edit_json(prompt):
+ raise ValueError(f"`prompt` has to be a valid JSON string or dict but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 3000:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
+
+ def create_attention_matrix(self, attention_mask):
+ attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
+
+ # convert to 0 - keep, -inf ignore
+ attention_matrix = torch.where(
+ attention_matrix == 1, 0.0, -torch.inf
+ ) # Apply -inf to ignored tokens for nulling softmax score
+ return attention_matrix
diff --git a/src/diffusers/pipelines/chroma/__init__.py b/src/diffusers/pipelines/chroma/__init__.py
index d9238b735c41..25069b5543c1 100644
--- a/src/diffusers/pipelines/chroma/__init__.py
+++ b/src/diffusers/pipelines/chroma/__init__.py
@@ -24,6 +24,7 @@
else:
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
+ _import_structure["pipeline_chroma_inpainting"] = ["ChromaInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -33,6 +34,7 @@
else:
from .pipeline_chroma import ChromaPipeline
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
+ from .pipeline_chroma_inpainting import ChromaInpaintPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py
new file mode 100644
index 000000000000..019c14415202
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py
@@ -0,0 +1,1197 @@
+"""
+ChromaInpaintPipeline implements a text-guided image inpainting pipeline for the lodestones/Chroma1-HD model, based on
+the ChromaPipeline from Hugging Face Diffusers and the Stable Diffusion inpainting approach.
+"""
+
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import ChromaTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..chroma.pipeline_output import ChromaPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ChromaInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = ChromaInpaintPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0]
+ >>> image.save("chroma_inpainting.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
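+
+
+# Minimal usage sketch for `retrieve_timesteps` (illustration only, assuming a scheduler whose
+# `set_timesteps` accepts `sigmas` and `mu`, e.g. `FlowMatchEulerDiscreteScheduler`):
+#
+#     sigmas = np.linspace(1.0, 1 / 28, 28)
+#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, device="cuda", sigmas=list(sigmas), mu=1.15)
+#
+# Passing both `timesteps` and `sigmas` raises a ValueError; passing neither falls back to the
+# scheduler's default spacing for `num_inference_steps` steps.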
+
+
+class ChromaInpaintPipeline(
+ DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, FluxIPAdapterMixin
+):
+ r"""
+ The Flux pipeline for image inpainting.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`ChromaTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`DDIMScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: ChromaTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.default_sample_size = 128
+
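+        # The mask processor deliberately skips normalization and instead binarizes and converts
+        # masks to grayscale, so resized masks contain strictly 0 (keep) / 1 (inpaint) values.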
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str], None] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ tokenizer_mask = text_inputs.attention_mask
+
+ tokenizer_mask = tokenizer_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=False,
+ attention_mask=tokenizer_mask,
+ )[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
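+        # Rebuild the attention mask from the tokenizer mask so that one trailing padding token
+        # stays unmasked (note the `<=` rather than `<`), per Chroma's T5 padding-token masking
+        # recommendation referenced in the `__call__` docstring.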
+ seq_lengths = tokenizer_mask.sum(dim=1)
+ mask_indices = torch.arange(tokenizer_mask.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
+ attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Union[str, List[str], None] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ do_classifier_free_guidance: bool = True,
+ max_sequence_length: int = 256,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device, dtype=dtype)
+ negative_text_ids = None
+
+ if do_classifier_free_guidance:
+ if negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = (
+ batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3, device=device, dtype=dtype)
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return (
+ prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_text_ids,
+ negative_prompt_attention_mask,
+ )
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError(
+ "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
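+        # Shape sketch (illustration only): for a 1024x1024 request the adjusted latent grid is
+        # 128x128 with 16 latent channels, so packing turns (batch, 16, 128, 128) tensors into
+        # (batch, 64 * 64, 64), i.e. 4096 tokens of 64 channels each.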
+ return latents, noise, image_latents, latent_image_ids
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (
+ masked_image_latents - self.vae.config.shift_factor
+ ) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ def _prepare_attention_mask(
+ self,
+ batch_size,
+ sequence_length,
+ dtype,
+ attention_mask=None,
+ ):
+ if attention_mask is None:
+ return attention_mask
+
+ # Extend the prompt attention mask to account for image tokens in the final sequence
+ attention_mask = torch.cat(
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
+ dim=1,
+ )
+ attention_mask = attention_mask.to(dtype)
+
+ return attention_mask
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 1.0,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ not greater than `1`).
+            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+            strength (`float`, *optional*, defaults to 0.6):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+                be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
+                Chroma requires a single padding token to remain unmasked. Please refer to
+                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
+                prompt sequence. Chroma requires a single padding token to remain unmasked. Please refer to
+                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.chroma.ChromaPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 256): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ output_type=output_type,
+ strength=strength,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ image=image,
+ mask_image=mask_image,
+ padding_mask_crop=padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_text_ids,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
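+        # e.g. (illustration only) strength=0.6 with 28 scheduled steps keeps roughly the last 17
+        # timesteps, so lower strength preserves more of the original image.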
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
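+        # `latent_timestep` is the first retained timestep, broadcast across the batch; it controls
+        # how much noise `prepare_latents` adds to the encoded input image.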
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ num_channels_transformer = self.transformer.config.in_channels
+
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
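+        # `mask` comes back packed to the same (batch, seq_len, channels) layout as `latents`, so it
+        # can gate the per-token latent blend in the denoising loop below.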
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
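+        # Extend the prompt attention masks with ones for the `image_seq_len` packed image tokens
+        # so text and image positions share a single attention mask.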
+ attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=prompt_attention_mask,
+ )
+ negative_attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=negative_prompt_attention_mask,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ attention_mask=attention_mask,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ attention_mask=negative_attention_mask,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                # Blend the denoised latents with the (re-noised) original image latents so that
+                # only the masked region is updated at this step.
+                init_latents_proper = image_latents
+                init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
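+            # e.g. (illustration only) a (batch, 4096, 64) packed tensor for a 1024x1024 image
+            # unpacks back to (batch, 16, 128, 128) before VAE decoding.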
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ChromaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index 4ac33b24bbe1..245c794c9c93 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -664,7 +664,13 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
self._num_timesteps = len(timesteps)
# 5. Prepare latents
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index c1335839f848..456f0bda1644 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -717,7 +717,13 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
self._num_timesteps = len(timesteps)
# 5. Prepare latents
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index c523c9adec98..321f0f073fe7 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -762,7 +762,13 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
self._num_timesteps = len(timesteps)
# 5. Prepare latents
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 897dc6d1b70a..e27c572020d6 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -737,7 +737,13 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 304a5c5ad00b..46f60d24a467 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -566,7 +566,13 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 22510f5d9d50..9a2d555538d5 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -599,8 +599,12 @@ def __call__(
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu
)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
index e26b7ba415de..2d6785f791db 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
@@ -649,8 +649,12 @@ def __call__(
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu
)
self._num_timesteps = len(timesteps)
# Denoising loop
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index fe0e69314cca..e2fb32688392 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -1195,8 +1195,12 @@ def __call__(
assert False
# 5. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 0e2a1441f8f6..283c3f92390c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1344,8 +1344,12 @@ def __call__(
assert False
# 5. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 94c4c394465b..2ea7307fec32 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -84,7 +84,6 @@
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
-
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
index 40cc76cf70d8..99f2958b320e 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
@@ -1339,8 +1339,12 @@ def __call__(
height, width = control_image[0][0].shape[-2:]
# 5. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index d605eac1f2b1..d721acc77c2a 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -1098,7 +1098,13 @@ def __call__(
assert False
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index 9d0158c6b654..2071305cdf10 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -15,6 +15,8 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import numpy as np
+import PIL.Image
import torch
from transformers import (
CLIPTextModelWithProjection,
@@ -39,7 +41,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
@@ -227,6 +229,8 @@ def __init__(
feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = SD3MultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
@@ -572,14 +576,52 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
def check_inputs(
self,
+ height,
+ width,
+ image,
prompt,
prompt_2,
prompt_3,
- height,
- width,
negative_prompt=None,
negative_prompt_2=None,
negative_prompt_3=None,
@@ -587,6 +629,11 @@ def check_inputs(
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -669,6 +716,76 @@ def check_inputs(
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, SD3MultiControlNetModel):
+ if isinstance(prompt, list) and len(prompt) > 1:
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, SD3ControlNetModel):
+ self.check_image(image, prompt, prompt_embeds)
+ elif isinstance(controlnet, SD3MultiControlNetModel):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+
+ # Check `controlnet_conditioning_scale`
+ if isinstance(controlnet, SD3MultiControlNetModel):
+ if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(controlnet, SD3MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
def prepare_latents(
self,
@@ -1040,11 +1157,12 @@ def __call__(
# 1. Check inputs. Raise error if not correct
self.check_inputs(
+ height,
+ width,
+ control_image,
prompt,
prompt_2,
prompt_3,
- height,
- width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
@@ -1052,6 +1170,11 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ ip_adapter_image=ip_adapter_image,
+ ip_adapter_image_embeds=ip_adapter_image_embeds,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ control_guidance_start=control_guidance_start,
+ control_guidance_end=control_guidance_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -1119,9 +1242,26 @@ def __call__(
width = latent_width * self.vae_scale_factor
elif isinstance(self.controlnet, SD3MultiControlNetModel):
- raise NotImplementedError("MultiControlNetModel is not supported for SD3ControlNetInpaintingPipeline.")
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_image_with_mask(
+ image=control_image_,
+ mask=control_mask,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=False,
+ )
+ control_images.append(control_image_)
+
+ control_image = control_images
else:
- assert False
+ raise ValueError("Controlnet not found. Please check the controlnet model.")
if controlnet_pooled_projections is None:
controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
@@ -1129,7 +1269,13 @@ def __call__(
controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
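+ # When torch_xla is available, build the timestep schedule on CPU (assumption: this keeps the small
+ # schedule tensors off the XLA device).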
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py
index 372684e0b521..ea9df999ddd6 100644
--- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py
@@ -76,7 +76,7 @@ def retrieve_latents(
>>> model_id = "nvidia/Cosmos-Predict2.5-2B"
>>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
- ... model_id, revision="diffusers/base/pre-trianed", torch_dtype=torch.bfloat16
+ ... model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
... )
>>> pipe = pipe.to("cuda")
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
index 92239c0d32f0..86c4d6812130 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
@@ -666,12 +666,18 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, mu=1
+ self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1
)
else:
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
index f74a11f87d75..b28a2c9fb273 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
@@ -810,12 +810,18 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, mu=1
+ self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1
)
else:
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
index b16ef92d8e6b..ec394315ee93 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
@@ -956,12 +956,18 @@ def __call__(
)
# 4. set timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, mu=1
+ self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1
)
else:
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 5041e352f73d..9562722dbee3 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -876,10 +876,15 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index 848d7bd39254..77f971d57a80 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -829,10 +829,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
index 262345c75afc..e1bbc6735051 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
@@ -810,10 +810,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
index 6915a83a7ca7..b02e74d3b2d6 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -1013,10 +1013,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 507ec687347c..78de4f617f84 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -1002,10 +1002,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 582c7bbad84e..5bf593258f49 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -873,10 +873,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index f7f34ef231af..a1e1f5f5e9e5 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -1020,10 +1020,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index 5cb9c82204b2..8ec9871d2579 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -932,10 +932,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index ab9140dae921..5166a6497e01 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -940,10 +940,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index 3bfe82cf4382..64a81fb0699f 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -1015,10 +1015,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py
index d986c9a63011..f6e1d5206630 100644
--- a/src/diffusers/pipelines/flux2/__init__.py
+++ b/src/diffusers/pipelines/flux2/__init__.py
@@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
+ _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -31,6 +32,7 @@
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_flux2 import Flux2Pipeline
+ from .pipeline_flux2_klein import Flux2KleinPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py
index b54a43dd89a5..c01b7137e086 100644
--- a/src/diffusers/pipelines/flux2/pipeline_flux2.py
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py
@@ -725,8 +725,8 @@ def guidance_scale(self):
return self._guidance_scale
@property
- def joint_attention_kwargs(self):
- return self._joint_attention_kwargs
+ def attention_kwargs(self):
+ return self._attention_kwargs
@property
def num_timesteps(self):
@@ -975,7 +975,7 @@ def __call__(
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
- joint_attention_kwargs=self._attention_kwargs,
+ joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
new file mode 100644
index 000000000000..efb0aebf8593
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
@@ -0,0 +1,918 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
+
+from ...loaders import Flux2LoraLoaderMixin
+from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .image_processor import Flux2ImageProcessor
+from .pipeline_output import Flux2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import Flux2KleinPipeline
+
+ >>> pipe = Flux2KleinPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=4.0).images[0]
+ >>> image.save("flux2_output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
+ a1, b1 = 8.73809524e-05, 1.89833333
+ a2, b2 = 0.00016927, 0.45666666
+
+ if image_seq_len > 4300:
+ mu = a2 * image_seq_len + b2
+ return float(mu)
+
+ m_200 = a2 * image_seq_len + b2
+ m_10 = a1 * image_seq_len + b1
+
+ a = (m_200 - m_10) / 190.0
+ b = m_200 - 200.0 * a
+ mu = a * num_steps + b
+
+ return float(mu)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
+ r"""
+ The Flux2 Klein pipeline for text-to-image generation.
+
+ Reference:
+ [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
+
+ Args:
+ transformer ([`Flux2Transformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLFlux2`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen3ForCausalLM`]):
+ [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
+ tokenizer (`Qwen2TokenizerFast`):
+ Tokenizer of class
+ [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLFlux2,
+ text_encoder: Qwen3ForCausalLM,
+ tokenizer: Qwen2TokenizerFast,
+ transformer: Flux2Transformer2DModel,
+ is_distilled: bool = False,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer,
+ )
+
+ self.register_to_config(is_distilled=is_distilled)
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be
+ # divisible by the patch size, so the vae scale factor is multiplied by the patch size to account for this.
+ self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 512
+ self.default_sample_size = 128
+
+ @staticmethod
+ def _get_qwen3_prompt_embeds(
+ text_encoder: Qwen3ForCausalLM,
+ tokenizer: Qwen2TokenizerFast,
+ prompt: Union[str, List[str]],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 512,
+ hidden_states_layers: Tuple[int, ...] = (9, 18, 27),
+ ):
+ dtype = text_encoder.dtype if dtype is None else dtype
+ device = text_encoder.device if device is None else device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ all_input_ids = []
+ all_attention_masks = []
+
+ for single_prompt in prompt:
+ messages = [{"role": "user", "content": single_prompt}]
+ text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ inputs = tokenizer(
+ text,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_sequence_length,
+ )
+
+ all_input_ids.append(inputs["input_ids"])
+ all_attention_masks.append(inputs["attention_mask"])
+
+ input_ids = torch.cat(all_input_ids, dim=0).to(device)
+ attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
+
+ # Forward pass through the model
+ output = text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ use_cache=False,
+ )
+
+ # Only use outputs from intermediate layers and stack them
+ out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+ out = out.to(dtype=dtype, device=device)
+
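+ # Fold the stacked layer outputs into the feature dimension: (B, num_layers, seq_len, hidden) -> (B, seq_len, num_layers * hidden).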
+ batch_size, num_channels, seq_len, hidden_dim = out.shape
+ prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+ return prompt_embeds
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
+ def _prepare_text_ids(
+ x: torch.Tensor, # (B, L, D) or (L, D)
+ t_coord: Optional[torch.Tensor] = None,
+ ):
+ B, L, _ = x.shape
+ out_ids = []
+
+ for i in range(B):
+ t = torch.arange(1) if t_coord is None else t_coord[i]
+ h = torch.arange(1)
+ w = torch.arange(1)
+ l = torch.arange(L)
+
+ coords = torch.cartesian_prod(t, h, w, l)
+ out_ids.append(coords)
+
+ return torch.stack(out_ids)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
+ def _prepare_latent_ids(
+ latents: torch.Tensor, # (B, C, H, W)
+ ):
+ r"""
+ Generates 4D position coordinates (T, H, W, L) for latent tensors.
+
+ Args:
+ latents (torch.Tensor):
+ Latent tensor of shape (B, C, H, W)
+
+ Returns:
+ torch.Tensor:
+ Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
+ H=[0..H-1], W=[0..W-1], L=0
+ """
+
+ batch_size, _, height, width = latents.shape
+
+ t = torch.arange(1) # [0] - time dimension
+ h = torch.arange(height)
+ w = torch.arange(width)
+ l = torch.arange(1) # [0] - layer dimension
+
+ # Create position IDs: (H*W, 4)
+ latent_ids = torch.cartesian_prod(t, h, w, l)
+
+ # Expand to batch: (B, H*W, 4)
+ latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
+
+ return latent_ids
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
+ def _prepare_image_ids(
+ image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
+ scale: int = 10,
+ ):
+ r"""
+ Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
+
+ This function creates a unique coordinate for every pixel/patch across all input latents, which may have
+ different dimensions.
+
+ Args:
+ image_latents (List[torch.Tensor]):
+ A list of image latent feature tensors, typically of shape (C, H, W).
+ scale (int, optional):
+ A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
+ latent is: 'scale + scale * i'. Defaults to 10.
+
+ Returns:
+ torch.Tensor:
+ The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
+ input latents.
+
+ Coordinate Components (Dimension 4):
+ - T (Time): The unique index indicating which latent image the coordinate belongs to.
+ - H (Height): The row index within that latent image.
+ - W (Width): The column index within that latent image.
+ - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
+ """
+
+ if not isinstance(image_latents, list):
+ raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
+
+ # create time offset for each reference image
+ t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
+ t_coords = [t.view(-1) for t in t_coords]
+
+ image_latent_ids = []
+ for x, t in zip(image_latents, t_coords):
+ x = x.squeeze(0)
+ _, height, width = x.shape
+
+ x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
+ image_latent_ids.append(x_ids)
+
+ image_latent_ids = torch.cat(image_latent_ids, dim=0)
+ image_latent_ids = image_latent_ids.unsqueeze(0)
+
+ return image_latent_ids
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
+ def _patchify_latents(latents):
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
+ def _unpatchify_latents(latents):
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
+ def _pack_latents(latents):
+ """
+ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
+ """
+
+ batch_size, num_channels, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
+ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
+ """
+ using position ids to scatter tokens into place
+ """
+ x_list = []
+ for data, pos in zip(x, x_ids):
+ _, ch = data.shape # noqa: F841
+ h_ids = pos[:, 1].to(torch.int64)
+ w_ids = pos[:, 2].to(torch.int64)
+
+ h = torch.max(h_ids) + 1
+ w = torch.max(w_ids) + 1
+
+ flat_ids = h_ids * w + w_ids
+
+ out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
+ out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
+
+ # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
+
+ out = out.view(h, w, ch).permute(2, 0, 1)
+ x_list.append(out)
+
+ return torch.stack(x_list, dim=0)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ text_encoder_out_layers: Tuple[int] = (9, 18, 27),
+ ):
+ device = device or self._execution_device
+
+ if prompt is None:
+ prompt = ""
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_qwen3_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ hidden_states_layers=text_encoder_out_layers,
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
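+ # Text tokens get 4D position ids with T=H=W=0 and L running over the sequence positions (see _prepare_text_ids).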
+ text_ids = self._prepare_text_ids(prompt_embeds)
+ text_ids = text_ids.to(device)
+ return prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if image.ndim != 4:
+ raise ValueError(f"Expected image dims 4, got {image.ndim}.")
+
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ image_latents = self._patchify_latents(image_latents)
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
+ image_latents = (image_latents - latents_bn_mean) / latents_bn_std
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_latents_channels,
+ height,
+ width,
+ dtype,
+ device,
+ generator: torch.Generator,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ latent_ids = self._prepare_latent_ids(latents)
+ latent_ids = latent_ids.to(device)
+
+ latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
+ return latents, latent_ids
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
+ def prepare_image_latents(
+ self,
+ images: List[torch.Tensor],
+ batch_size,
+ generator: torch.Generator,
+ device,
+ dtype,
+ ):
+ image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+ image_latent = self._encode_vae_image(image=image, generator=generator)
+ image_latents.append(image_latent) # (1, 128, 32, 32)
+
+ image_latent_ids = self._prepare_image_ids(image_latents)
+
+ # Pack each latent and concatenate
+ packed_latents = []
+ for latent in image_latents:
+ # latent: (1, 128, 32, 32)
+ packed = self._pack_latents(latent) # (1, 1024, 128)
+ packed = packed.squeeze(0) # (1024, 128) - remove batch dim
+ packed_latents.append(packed)
+
+ # Concatenate all reference tokens along sequence dimension
+ image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
+ image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
+
+ image_latents = image_latents.repeat(batch_size, 1, 1)
+ image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
+ image_latent_ids = image_latent_ids.to(device)
+
+ return image_latents, image_latent_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ guidance_scale=None,
+ ):
+ if (
+ height is not None
+ and height % (self.vae_scale_factor * 2) != 0
+ or width is not None
+ and width % (self.vae_scale_factor * 2) != 0
+ ):
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if guidance_scale > 1.0 and self.config.is_distilled:
+ logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
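+ # CFG is applied only when guidance_scale > 1 and the checkpoint is not step-distilled.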
+ return self._guidance_scale > 1 and not self.config.is_distilled
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ text_encoder_out_layers: Tuple[int] = (9, 18, 27),
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if latents are passed directly they are not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` must be passed
+ instead.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+ of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models,
+ `guidance_scale` is ignored.
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline.
+ If not provided, will be generated from "".
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+ text_encoder_out_layers (`Tuple[int]`):
+ Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ guidance_scale=guidance_scale,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. prepare text embeddings
+ prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ text_encoder_out_layers=text_encoder_out_layers,
+ )
+
+ if self.do_classifier_free_guidance:
+ negative_prompt = ""
+ if prompt is not None and isinstance(prompt, list):
+ negative_prompt = [negative_prompt] * len(prompt)
+ negative_prompt_embeds, negative_text_ids = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ text_encoder_out_layers=text_encoder_out_layers,
+ )
+
+ # 4. process images
+ if image is not None and not isinstance(image, list):
+ image = [image]
+
+ condition_images = None
+ if image is not None:
+ for img in image:
+ self.image_processor.check_image_input(img)
+
+ condition_images = []
+ for img in image:
+ image_width, image_height = img.size
+ if image_width * image_height > 1024 * 1024:
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
+ image_width, image_height = img.size
+
+ multiple_of = self.vae_scale_factor * 2
+ image_width = (image_width // multiple_of) * multiple_of
+ image_height = (image_height // multiple_of) * multiple_of
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
+ condition_images.append(img)
+ height = height or image_height
+ width = width or image_width
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 5. prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_ids = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_latents_channels=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ image_latents = None
+ image_latent_ids = None
+ if condition_images is not None:
+ image_latents, image_latent_ids = self.prepare_image_latents(
+ images=condition_images,
+ batch_size=batch_size * num_images_per_prompt,
+ generator=generator,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ # 6. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None
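+ # The flow-matching shift mu is derived empirically from the image token count and the step count.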
+ image_seq_len = latents.shape[1]
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 7. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ latent_model_input = latents.to(self.transformer.dtype)
+ latent_image_ids = latent_ids
+
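+ # If reference images were provided, append their packed tokens after the noise tokens; the
+ # transformer predictions for these extra tokens are sliced off right after the call below.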
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input, # (B, image_seq_len, C)
+ timestep=timestep / 1000,
+ guidance=None,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=latent_image_ids, # B, image_seq_len, 4
+ joint_attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=None,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
+
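+ # Map the denoised latents from the normalized (batch-norm statistics) space used by the transformer
+ # back to the VAE latent space before decoding.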
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+ latents.device, latents.dtype
+ )
+ latents = latents * latents_bn_std + latents_bn_mean
+ latents = self._unpatchify_latents(latents)
+ if output_type == "latent":
+ image = latents
+ else:
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return Flux2PipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py
new file mode 100644
index 000000000000..140b9cc760cc
--- /dev/null
+++ b/src/diffusers/pipelines/glm_image/__init__.py
@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}
+
+# Import transformers components so they can be resolved during pipeline loading
+
+if is_transformers_available() and is_transformers_version(">=", "4.57.4"):
+ try:
+ from transformers import GlmImageForConditionalGeneration, GlmImageProcessor
+
+ _additional_imports["GlmImageForConditionalGeneration"] = GlmImageForConditionalGeneration
+ _additional_imports["GlmImageProcessor"] = GlmImageProcessor
+ except ImportError:
+ pass
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_glm_image import GlmImagePipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
new file mode 100644
index 000000000000..5499b8769fa6
--- /dev/null
+++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
@@ -0,0 +1,825 @@
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import ByT5Tokenizer, PreTrainedModel, ProcessorMixin, T5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, GlmImageTransformer2DModel
+from ...models.transformers.transformer_glm_image import GlmImageKVCache
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, is_transformers_version, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from .pipeline_output import GlmImagePipelineOutput
+
+
+# GlmImageProcessor and GlmImageForConditionalGeneration are not available in a stable transformers
+# release as of 13/01/2026, so generic base classes are used as proxies in the meantime.
+GlmImageProcessor = ProcessorMixin
+GlmImageForConditionalGeneration = PreTrainedModel
+if is_transformers_version(">=", "5.0.0.dev0"):
+ from transformers import GlmImageForConditionalGeneration, GlmImageProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import GlmImagePipeline
+
+ >>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ base_shift: float = 0.25,
+ max_shift: float = 0.75,
+) -> float:
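+    # mu = sqrt(image_seq_len / base_seq_len) * max_shift + base_shift: longer image token
+    # sequences get a larger timestep shift for the flow-matching schedule.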
+ m = (image_seq_len / base_seq_len) ** 0.5
+ mu = m * max_shift + base_shift
+ return mu
+
+
+# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+
+ if timesteps is not None and sigmas is not None:
+ if not accepts_timesteps and not accepts_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif timesteps is not None and sigmas is None:
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif timesteps is None and sigmas is not None:
+ if not accepts_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class GlmImagePipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using GLM-Image.
+
+ This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion
+ transformer) model for image decoding.
+
+ Args:
+ tokenizer (`PreTrainedTokenizer`):
+ Tokenizer for the text encoder.
+ processor (`AutoProcessor`):
+ Processor for the AR model to handle chat templates and tokenization.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder for glyph embeddings.
+ vision_language_encoder ([`GlmImageForConditionalGeneration`]):
+ The AR model that generates image tokens from text prompts.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ transformer ([`GlmImageTransformer2DModel`]):
+ A text conditioned transformer to denoise the encoded image latents (DiT).
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "vision_language_encoder->text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: ByT5Tokenizer,
+ processor: GlmImageProcessor,
+ text_encoder: T5EncoderModel,
+ vision_language_encoder: GlmImageForConditionalGeneration,
+ vae: AutoencoderKL,
+ transformer: GlmImageTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ processor=processor,
+ text_encoder=text_encoder,
+ vision_language_encoder=vision_language_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer")
+ and self.transformer is not None
+ and hasattr(self.transformer.config, "sample_size")
+ else 128
+ )
+
+ @staticmethod
+ def _compute_generation_params(
+ image_grid_thw,
+ is_text_to_image: bool,
+ ):
+ grid_sizes = []
+ grid_hw = []
+
+ for i in range(image_grid_thw.shape[0]):
+ t, h, w = image_grid_thw[i].tolist()
+ grid_sizes.append(int(h * w))
+ grid_hw.append((int(h), int(w)))
+
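+        # `max_new_tokens` covers the image tokens still to be generated plus one EOS token, and
+        # `large_image_start_offset` locates the target image's tokens inside the generated sequence.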
+ if not is_text_to_image:
+ max_new_tokens = grid_sizes[-1] + 1
+ large_image_start_offset = 0
+ target_grid_h, target_grid_w = grid_hw[-1]
+ else:
+ total_tokens = sum(grid_sizes)
+ max_new_tokens = total_tokens + 1
+ large_image_start_offset = sum(grid_sizes[1:])
+ target_grid_h, target_grid_w = grid_hw[0]
+ return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w
+
+ @staticmethod
+ def _extract_large_image_tokens(
+ outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int
+ ) -> torch.Tensor:
+ generated_tokens = outputs[0][input_length:]
+ large_image_start = large_image_start_offset
+ large_image_end = large_image_start + large_image_tokens
+ return generated_tokens[large_image_start:large_image_end]
+
+ @staticmethod
+ def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
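+        # Nearest-neighbor upsample the (token_h, token_w) id grid by 2x, duplicating each AR token
+        # id into a 2x2 block to match the denser token grid consumed by the DiT.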
+ token_ids = token_ids.view(1, 1, token_h, token_w)
+ token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
+ dtype=torch.long
+ )
+ token_ids = token_ids.view(1, -1)
+ return token_ids
+
+ def generate_prior_tokens(
+ self,
+ prompt: str,
+ height: int,
+ width: int,
+ image: Optional[List[PIL.Image.Image]] = None,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ is_text_to_image = image is None or len(image) == 0
+ content = []
+ if image is not None:
+ for img in image:
+ content.append({"type": "image", "image": img})
+ content.append({"type": "text", "text": prompt})
+ messages = [{"role": "user", "content": content}]
+ inputs = self.processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ target_h=height,
+ target_w=width,
+ return_dict=True,
+ return_tensors="pt",
+ ).to(device)
+
+ image_grid_thw = inputs.get("image_grid_thw")
+ max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
+ image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
+ )
+
+ prior_token_image_ids = None
+ if image is not None:
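+            # Encode the condition images with the AR vision encoder and convert the features into
+            # prior token ids for each condition image grid.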
+ prior_token_image_embed = self.vision_language_encoder.get_image_features(
+ inputs["pixel_values"], image_grid_thw[:-1]
+ )
+ prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
+ prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
+ prior_token_image_embed, image_grid_thw[:-1]
+ )
+
+ # For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
+ # max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
+ outputs = self.vision_language_encoder.generate(
+ **inputs,
+ max_new_tokens=max_new_tokens,
+ do_sample=True,
+ )
+
+ prior_token_ids_d32 = self._extract_large_image_tokens(
+ outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
+ )
+ prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
+
+ return prior_token_ids, prior_token_image_ids
+
+ def get_glyph_texts(self, prompt):
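+        # Collect spans wrapped in ASCII or CJK quotation marks from the prompt; these are the
+        # texts that get encoded into glyph embeddings.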
+ prompt = prompt[0] if isinstance(prompt, list) else prompt
+ ocr_texts = (
+ re.findall(r"'([^']*)'", prompt)
+ + re.findall(r"“([^“”]*)”", prompt)
+ + re.findall(r'"([^"]*)"', prompt)
+ + re.findall(r"「([^「」]*)」", prompt)
+ )
+ return ocr_texts
+
+ def _get_glyph_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 2048,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ glyph_texts = self.get_glyph_texts(prompt)
+ input_ids = self.tokenizer(
+ glyph_texts if len(glyph_texts) > 0 else [""],
+ max_length=max_sequence_length,
+ truncation=True,
+ ).input_ids
+ input_ids = [
+ [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
+ ]
+ max_length = max(len(input_ids_) for input_ids_ in input_ids)
+ attention_mask = torch.tensor(
+ [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
+ )
+ input_ids = torch.tensor(
+ [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
+ device=device,
+ )
+ outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
+ glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
+
+ return glyph_embeds.to(device=device, dtype=dtype)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 2048,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+                Number of images that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ max_sequence_length (`int`, defaults to `2048`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
+
+ seq_len = prompt_embeds.size(1)
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For GLM-Image, negative_prompt must be "" instead of None
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
+
+ seq_len = negative_prompt_embeds.size(1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prior_token_ids=None,
+ prior_image_token_ids=None,
+ ):
+        if (
+            height is not None
+            and height % (self.vae_scale_factor * self.transformer.config.patch_size * 2) != 0
+            or width is not None
+            and width % (self.vae_scale_factor * self.transformer.config.patch_size * 2) != 0
+        ):
+            # GLM-Image uses 32× downsampling, so the image dimensions must be multiples of 32.
+            raise ValueError(
+                f"`height` and `width` have to be divisible by"
+                f" {self.vae_scale_factor * self.transformer.config.patch_size * 2} but are {height} and {width}."
+            )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if prompt is not None and prior_token_ids is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prior_token_ids is None:
+ raise ValueError(
+ "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
+ )
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if (prior_token_ids is None and prior_image_token_ids is not None) or (
+ prior_token_ids is not None and prior_image_token_ids is None
+ ):
+ raise ValueError(
+ f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:"
+ f" {prior_image_token_ids} provided. Please make sure both are provided or neither."
+ )
+
+ if prior_token_ids is not None and prompt_embeds is None:
+ raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ image: Optional[
+ Union[
+ torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
+ ]
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 1.5,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prior_token_ids: Optional[torch.FloatTensor] = None,
+ prior_image_token_ids: Optional[torch.Tensor] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 2048,
+ ) -> Union[GlmImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. Must contain shape info giving the target
+                token-grid dimensions H and W (in d32 token units); for example, a grid of 36 x 24 tokens
+                corresponds to a 1152x768 image.
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or lists thereof, *optional*):
+                Condition images for image-to-image generation.
+ height (`int`, *optional*):
+ The height in pixels. If not provided, derived from prompt shape info.
+ width (`int`, *optional*):
+ The width in pixels. If not provided, derived from prompt shape info.
+ num_inference_steps (`int`, *optional*, defaults to `50`):
+ The number of denoising steps for DiT.
+ guidance_scale (`float`, *optional*, defaults to `1.5`):
+ Guidance scale for classifier-free guidance.
+ num_images_per_prompt (`int`, *optional*, defaults to `1`):
+ The number of images to generate per prompt.
+ generator (`torch.Generator`, *optional*):
+ Random generator for reproducibility.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ Output format: "pil", "np", or "latent".
+
+ Examples:
+
+ Returns:
+ [`GlmImagePipelineOutput`] or `tuple`: Generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prior_token_ids,
+ prior_image_token_ids,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if batch_size != 1:
+ raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")
+
+ device = self._execution_device
+
+ # 2. Preprocess image tokens and prompt tokens
+        if prior_token_ids is None:
+            prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
+                prompt=prompt[0] if isinstance(prompt, list) else prompt,
+                image=image,
+                height=height,
+                width=width,
+                device=device,
+            )
+        else:
+            prior_token_image_ids = prior_image_token_ids
+
+ # 3. Preprocess image
+ if image is not None:
+ preprocessed_condition_images = []
+ for img in image:
+ image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
+ multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
+ image_height = (image_height // multiple_of) * multiple_of
+ image_width = (image_width // multiple_of) * multiple_of
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width)
+ preprocessed_condition_images.append(img)
+ height = height or image_height
+ width = width or image_width
+ image = preprocessed_condition_images
+
+        # 4. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=self.dtype,
+ )
+
+        # 5. Prepare latents and (optional) image kv cache
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=latent_channels,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+ kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
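+        # The KV cache is filled in "write" mode by the reference-image forward pass below, read back
+        # during the conditional denoising passes, and skipped for the unconditional passes.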
+
+ if image is not None:
+ kv_caches.set_mode("write")
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
+
+ latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
+ latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)
+
+ for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
+ condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
+ condition_latent = retrieve_latents(
+ self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
+ )
+ condition_latent = (condition_latent - latents_mean) / latents_std
+
+                # Do not remove: this forward pass at timestep 0 runs the reference image through the
+                # transformer so that its KV cache is populated for the denoising loop.
+ _ = self.transformer(
+ hidden_states=condition_latent,
+ encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
+ prior_token_id=condition_image_prior_token_id,
+ prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
+ timestep=torch.zeros((1,), device=device),
+ target_size=torch.tensor([condition_image.shape[-2:]], device=device),
+ crop_coords=torch.zeros((1, 2), device=device),
+ attention_kwargs=attention_kwargs,
+ kv_caches=kv_caches,
+ )
+
+ # 6. Prepare additional timestep conditions
+ target_size = (height, width)
+ target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
+ crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
+
+ target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
+ crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
+
+ # Prepare timesteps
+ image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+ self.transformer.config.patch_size**2
+ )
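+        # Default schedule: integer timesteps spaced linearly from num_train_timesteps down to 1,
+        # with sigmas = timesteps / num_train_timesteps and a resolution-dependent shift `mu`.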
+ timesteps = (
+ np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
+ if timesteps is None
+ else np.array(timesteps)
+ )
+ timesteps = timesteps.astype(np.int64).astype(np.float32)
+ sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("base_shift", 0.25),
+ self.scheduler.config.get("max_shift", 0.75),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 7. Denoising loop
+ transformer_dtype = self.transformer.dtype
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
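+        # prior_token_drop: all-False keeps the AR prior tokens for the conditional branch,
+        # all-True drops them for the unconditional (CFG) branch.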
+ prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
+ prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+
+ timestep = t.expand(latents.shape[0]) - 1
+
+ if image is not None:
+ kv_caches.set_mode("read")
+
+ noise_pred_cond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ prior_token_id=prior_token_ids,
+ prior_token_drop=prior_token_drop_cond,
+ timestep=timestep,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ kv_caches=kv_caches,
+ )[0].float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ if image is not None:
+ kv_caches.set_mode("skip")
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ prior_token_id=prior_token_ids,
+ prior_token_drop=prior_token_drop_uncond,
+ timestep=timestep,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ kv_caches=kv_caches,
+ )[0].float()
+
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_cond
+
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ kv_caches.clear()
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.latent_channels, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, self.vae.config.latent_channels, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return GlmImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/glm_image/pipeline_output.py b/src/diffusers/pipelines/glm_image/pipeline_output.py
new file mode 100644
index 000000000000..aec5a5454ea8
--- /dev/null
+++ b/src/diffusers/pipelines/glm_image/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class GlmImagePipelineOutput(BaseOutput):
+ """
+    Output class for GLM-Image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
index b6af23bca8fd..b41d9772a7cc 100644
--- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
+++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
@@ -53,7 +53,6 @@
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
-
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
@@ -965,14 +964,18 @@ def __call__(
# 5. Prepare timesteps
mu = calculate_shift(self.transformer.max_seq)
scheduler_kwargs = {"mu": mu}
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
if isinstance(self.scheduler, UniPCMultistepScheduler):
- self.scheduler.set_timesteps(num_inference_steps, device=device) # , shift=math.exp(mu))
+ self.scheduler.set_timesteps(num_inference_steps, device=timestep_device) # , shift=math.exp(mu))
timesteps = self.scheduler.timesteps
else:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
**scheduler_kwargs,
)
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
index b50a6ae3ed27..6bb7a4344da5 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
@@ -728,7 +728,13 @@ def __call__(
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
# 5. Prepare latent variables
vae_dtype = self.vae.dtype
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index 5c8e295eaf4c..42ab090f1cba 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -683,7 +683,13 @@ def __call__(
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py
index 7c8468bcb109..3c7442afcaae 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py
@@ -877,8 +877,12 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index 10a7962c258c..8c3adf33b845 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -1034,8 +1034,12 @@ def __call__(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index 59f733a498ed..c28e358c51b6 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -881,10 +881,14 @@ def __call__(
image = self.image_processor.preprocess(image)
# 5. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
timesteps,
original_inference_steps=original_inference_steps,
strength=strength,
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index e463884618f5..bc71d7bd171a 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -815,8 +815,16 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps
+ self.scheduler,
+ num_inference_steps,
+ timestep_device,
+ timesteps,
+ original_inference_steps=original_inference_steps,
)
# 5. Prepare latent variable
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py
index 4d42a7049ec9..7fde18e4fbbb 100644
--- a/src/diffusers/pipelines/latte/pipeline_latte.py
+++ b/src/diffusers/pipelines/latte/pipeline_latte.py
@@ -767,7 +767,13 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py
index a758d545fa4a..ca28422f9ca0 100644
--- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py
+++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py
@@ -260,10 +260,10 @@ def rewire_prompt(self, prompt, device):
text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
all_text.append(text)
- inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device)
+ inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device)
- self.text_encoder.to(device)
generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length)
+        generated_ids = generated_ids.to(device)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = self.text_processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py
index 6001867916b3..05117d35d3b4 100644
--- a/src/diffusers/pipelines/ltx/__init__.py
+++ b/src/diffusers/pipelines/ltx/__init__.py
@@ -25,6 +25,7 @@
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
+ _import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
@@ -39,6 +40,7 @@
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
+ from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 8ca8b4419e18..3c90da1c7051 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -726,10 +726,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
timesteps,
sigmas=sigmas,
mu=mu,
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index 48a6f0837c8d..10c9432a7f46 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1102,11 +1102,24 @@ def __call__(
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
+
if timesteps is None:
sigmas = linear_quadratic_schedule(num_inference_steps)
timesteps = sigmas * 1000
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ timestep_device,
+ timesteps,
+ )
sigmas = self.scheduler.sigmas
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
latent_sigma = None
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py b/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py
new file mode 100644
index 000000000000..7965bd3b4b87
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py
@@ -0,0 +1,1408 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTXVideo
+from ...models.transformers import LTXVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler, LTXEulerAncestralRFScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LTXPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import LTXEulerAncestralRFScheduler, LTXI2VLongMultiPromptPipeline
+
+ >>> pipe = LTXI2VLongMultiPromptPipeline.from_pretrained("LTX-Video-0.9.8-13B-distilled")
+ >>> # For ComfyUI parity, swap in the RF scheduler (keeps the original config).
+ >>> pipe.scheduler = LTXEulerAncestralRFScheduler.from_config(pipe.scheduler.config)
+ >>> pipe = pipe.to("cuda").to(dtype=torch.bfloat16)
+ >>> # Example A: get decoded frames (PIL)
+ >>> out = pipe(
+ ... prompt="a chimpanzee walks | a chimpanzee eats",
+ ... num_frames=161,
+ ... height=512,
+ ... width=704,
+ ... temporal_tile_size=80,
+ ... temporal_overlap=24,
+ ... output_type="pil",
+ ... return_dict=True,
+ ... )
+ >>> frames = out.frames[0] # list of PIL.Image.Image
+ >>> # Example B: get latent video and decode later (saves VRAM during sampling)
+ >>> out_latent = pipe(prompt="a chimpanzee walking", output_type="latent", return_dict=True).frames
+ >>> frames = pipe.vae_decode_tiled(out_latent, output_type="pil")[0]
+ ```
+"""
+
+
+def get_latent_coords(
+ latent_num_frames, latent_height, latent_width, batch_size, device, rope_interpolation_scale, latent_idx
+):
+ """
+ Compute latent patch top-left coordinates in (t, y, x) order.
+
+ Args:
+ latent_num_frames: int. Number of latent frames (T_lat).
+ latent_height: int. Latent height (H_lat).
+ latent_width: int. Latent width (W_lat).
+ batch_size: int. Batch dimension (B).
+ device: torch.device for the resulting tensor.
+ rope_interpolation_scale:
+ tuple[int|float, int|float, int|float]. Scale per (t, y, x) latent step to pixel coords.
+ latent_idx: Optional[int]. When not None, shifts the time coordinate to align segments:
+ - <= 0 uses step multiples of rope_interpolation_scale[0]
+ - > 0 starts at 1 then increments by rope_interpolation_scale[0]
+
+ Returns:
+ Tensor of shape [B, 3, T_lat * H_lat * W_lat] containing top-left coordinates per latent patch, repeated for each
+ batch element.
+ """
+ latent_sample_coords = torch.meshgrid(
+ torch.arange(0, latent_num_frames, 1, device=device),
+ torch.arange(0, latent_height, 1, device=device),
+ torch.arange(0, latent_width, 1, device=device),
+ indexing="ij",
+ )
+ latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+ latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+ latent_coords = latent_coords.flatten(2)
+ pixel_coords = latent_coords * torch.tensor(rope_interpolation_scale, device=latent_coords.device)[None, :, None]
+ if latent_idx is not None:
+ if latent_idx <= 0:
+ frame_idx = latent_idx * rope_interpolation_scale[0]
+ else:
+ frame_idx = 1 + (latent_idx - 1) * rope_interpolation_scale[0]
+ if frame_idx == 0:
+ pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - rope_interpolation_scale[0]).clamp(min=0)
+ pixel_coords[:, 0] += frame_idx
+ return pixel_coords
+
+
+# Copied from diffusers.pipelines.ltx.pipeline_ltx.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def adain_normalize_latents(
+ curr_latents: torch.Tensor, ref_latents: Optional[torch.Tensor], factor: float
+) -> torch.Tensor:
+ """
+ Optional AdaIN normalization: channel-wise mean/variance matching of curr_latents to ref_latents, controlled by
+ factor.
+
+ Args:
+ curr_latents: Tensor [B, C, T, H, W]. Current window latents.
+ ref_latents:
+ Optional[Tensor] [B, C, T_ref, H, W]. Reference latents (e.g., first window) used to compute target stats.
+ factor: float in [0, 1]. 0 keeps current stats; 1 matches reference stats.
+
+ Returns:
+ Tensor with per-channel mean/std blended towards the reference.
+ """
+ if ref_latents is None or factor is None or factor <= 0:
+ return curr_latents
+
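+    # AdaIN: normalize the current latents with their own per-channel stats, then re-scale/shift them
+    # with statistics blended between current and reference by `factor`.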
+ eps = torch.tensor(1e-6, device=curr_latents.device, dtype=curr_latents.dtype)
+
+ # Compute per-channel means/stds for current and reference over (T, H, W)
+ mu_curr = curr_latents.mean(dim=(2, 3, 4), keepdim=True)
+ sigma_curr = curr_latents.std(dim=(2, 3, 4), keepdim=True)
+
+ mu_ref = ref_latents.mean(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype)
+ sigma_ref = ref_latents.std(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype)
+
+ # Blend target statistics
+ mu_blend = (1.0 - float(factor)) * mu_curr + float(factor) * mu_ref
+ sigma_blend = (1.0 - float(factor)) * sigma_curr + float(factor) * sigma_ref
+ sigma_blend = torch.clamp(sigma_blend, min=float(eps))
+
+ # Apply AdaIN
+ curr_norm = (curr_latents - mu_curr) / (sigma_curr + eps)
+ return curr_norm * sigma_blend + mu_blend
+
+
+def split_into_temporal_windows(
+ latent_len: int, temporal_tile_size: int, temporal_overlap: int, compression: int
+) -> List[Tuple[int, int]]:
+ """
+ Split latent frames into sliding windows.
+
+ Args:
+ latent_len: int. Number of latent frames (T_lat).
+ temporal_tile_size: int. Window size in latent frames (> 0).
+ temporal_overlap: int. Overlap between windows in latent frames (>= 0).
+ compression: int. VAE temporal compression ratio (unused here; kept for parity).
+
+ Returns:
+ list[tuple[int, int]]: inclusive-exclusive (start, end) indices per window.
+ """
+ if temporal_tile_size <= 0:
+ raise ValueError("temporal_tile_size must be > 0")
+ stride = max(temporal_tile_size - temporal_overlap, 1)
+ windows = []
+ start = 0
+ while start < latent_len:
+ end = min(start + temporal_tile_size, latent_len)
+ windows.append((start, end))
+ if end == latent_len:
+ break
+ start = start + stride
+ return windows
+
+
+def linear_overlap_fuse(prev: torch.Tensor, new: torch.Tensor, overlap: int) -> torch.Tensor:
+ """
+ Temporal linear crossfade between two latent clips over the overlap region.
+
+ Args:
+ prev: Tensor [B, C, F, H, W]. Previous output segment.
+ new: Tensor [B, C, F, H, W]. New segment to be appended.
+ overlap: int. Number of frames to crossfade (overlap <= 1 concatenates without blend).
+
+ Returns:
+ Tensor [B, C, F_prev + F_new - overlap, H, W] after crossfade at the seam.
+ """
+ if overlap <= 1:
+ return torch.cat([prev, new], dim=2)
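+    # Crossfade weights ramp linearly from ~1 to ~0 over the overlap (endpoints excluded), so the
+    # seam favors `prev` at its start and `new` at its end.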
+ alpha = torch.linspace(1, 0, overlap + 2, device=prev.device, dtype=prev.dtype)[1:-1]
+ shape = [1] * prev.ndim
+ shape[2] = alpha.size(0)
+ alpha = alpha.reshape(shape)
+ blended = alpha * prev[:, :, -overlap:] + (1 - alpha) * new[:, :, :overlap]
+ return torch.cat([prev[:, :, :-overlap], blended, new[:, :, overlap:]], dim=2)
+
+
+def inject_prev_tail_latents(
+ window_latents: torch.Tensor,
+ prev_tail_latents: Optional[torch.Tensor],
+ window_cond_mask_5d: torch.Tensor,
+ overlap_lat: int,
+ strength: Optional[float],
+ prev_overlap_len: int,
+) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+    Append the last k latent frames of the previous window to the current window as conditioning frames, where k =
+    min(overlap_lat, T_curr, T_prev_tail), and extend the conditioning mask with value (1 - strength) for them.
+
+ Args:
+ window_latents: Tensor [B, C, T, H, W]. Current window latents.
+ prev_tail_latents: Optional[Tensor] [B, C, T_prev, H, W]. Tail segment from the previous window.
+ window_cond_mask_5d: Tensor [B, 1, T, H, W]. Per-token conditioning mask (1 = free, 0 = hard condition).
+ overlap_lat: int. Number of latent frames to inject from the previous tail.
+ strength: Optional[float] in [0, 1]. Blend strength; 1.0 replaces, 0.0 keeps original.
+ prev_overlap_len: int. Accumulated overlap length so far (used for trimming later).
+
+ Returns:
+ Tuple[Tensor, Tensor, int]: (updated_window_latents, updated_cond_mask, updated_prev_overlap_len)
+ """
+ if prev_tail_latents is None or overlap_lat <= 0 or strength is None or strength <= 0:
+ return window_latents, window_cond_mask_5d, prev_overlap_len
+
+ # Expected shape: [B, C, T, H, W]
+ T = int(window_latents.shape[2])
+ k = min(int(overlap_lat), T, int(prev_tail_latents.shape[2]))
+ if k <= 0:
+ return window_latents, window_cond_mask_5d, prev_overlap_len
+
+ tail = prev_tail_latents[:, :, -k:]
+ mask = torch.full(
+ (window_cond_mask_5d.shape[0], 1, tail.shape[2], window_cond_mask_5d.shape[3], window_cond_mask_5d.shape[4]),
+ 1.0 - strength,
+ dtype=window_cond_mask_5d.dtype,
+ device=window_cond_mask_5d.device,
+ )
+
+ window_latents = torch.cat([window_latents, tail], dim=2)
+ window_cond_mask_5d = torch.cat([window_cond_mask_5d, mask], dim=2)
+ return window_latents, window_cond_mask_5d, prev_overlap_len + k
+
+
+def build_video_coords_for_window(
+ latents: torch.Tensor,
+ overlap_len: int,
+ guiding_len: int,
+ negative_len: int,
+ rope_interpolation_scale: torch.Tensor,
+ frame_rate: int,
+) -> torch.Tensor:
+ """
+ Build video_coords: [B, 3, S] with order [t, y, x].
+
+ Args:
+ latents: Tensor [B, C, T, H, W]. Current window latents (before any trimming).
+        overlap_len: int. Number of frames injected from the previous window's tail; their time coordinates are
+            remapped to the start of the window.
+        guiding_len: int. Number of guidance frames appended to the window, following the overlap frames in time.
+        negative_len: int. Number of negative-index frames appended to the window (typically 1 or 0).
+ rope_interpolation_scale: tuple[int|float, int|float, int|float]. Scale for (t, y, x).
+ frame_rate: int. Used to convert time indices into seconds (t /= frame_rate).
+
+ Returns:
+ Tensor [B, 3, T*H*W] of fractional pixel coordinates per latent patch.
+ """
+
+ b, c, f, h, w = latents.shape
+ pixel_coords = get_latent_coords(f, h, w, b, latents.device, rope_interpolation_scale, 0)
+    replace_coords = []
+    if overlap_len > 0:
+        replace_coords.append(get_latent_coords(overlap_len, h, w, b, latents.device, rope_interpolation_scale, 0))
+    if guiding_len > 0:
+        replace_coords.append(
+            get_latent_coords(guiding_len, h, w, b, latents.device, rope_interpolation_scale, overlap_len)
+        )
+    if negative_len > 0:
+        replace_coords.append(get_latent_coords(negative_len, h, w, b, latents.device, rope_interpolation_scale, -1))
+    if len(replace_coords) > 0:
+        replace_coords = torch.cat(replace_coords, dim=2)
+        pixel_coords[:, :, -replace_coords.shape[2] :] = replace_coords
+ fractional_coords = pixel_coords.to(torch.float32)
+ fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
+ return fractional_coords
+
+
+def parse_prompt_segments(prompt: Union[str, List[str]], prompt_segments: Optional[List[Dict[str, Any]]]) -> List[str]:
+ """
+ Return a list of positive prompts per window index.
+
+ Args:
+ prompt: str | list[str]. If str contains '|', parts are split by bars and trimmed.
+ prompt_segments:
+ list[dict], optional. Each dict with {"start_window", "end_window", "text"} overrides prompts per window.
+
+ Returns:
+ list[str] containing the positive prompt for each window index.
+ """
+ if prompt is None:
+ return []
+ if prompt_segments:
+ max_w = 0
+ for seg in prompt_segments:
+ max_w = max(max_w, int(seg.get("end_window", 0)))
+ texts = [""] * (max_w + 1)
+ for seg in prompt_segments:
+ s = int(seg.get("start_window", 0))
+ e = int(seg.get("end_window", s))
+ txt = seg.get("text", "")
+ for w in range(s, e + 1):
+ texts[w] = txt
+ # fill empty by last non-empty
+ last = ""
+ for i in range(len(texts)):
+ if texts[i] == "":
+ texts[i] = last
+ else:
+ last = texts[i]
+ return texts
+
+ # bar-split mode
+ if isinstance(prompt, str):
+ parts = [p.strip() for p in prompt.split("|")]
+ else:
+ parts = prompt
+ parts = [p for p in parts if p is not None]
+ return parts
+
+
+def batch_normalize(latents, reference, factor):
+ """
+ Batch AdaIN-like normalization for latents in dict format (ComfyUI-compatible).
+
+ Args:
+ latents: dict containing "samples" shaped [B, C, F, H, W]
+ reference: dict containing "samples" used to compute target stats
+ factor: float in [0, 1]; 0 = no change, 1 = full match to reference
+ Returns:
+ Tuple[dict]: a single-element tuple with the updated latents dict.
+ """
+ latents_copy = copy.deepcopy(latents)
+ t = latents_copy["samples"] # B x C x F x H x W
+
+ for i in range(t.size(0)): # batch
+ for c in range(t.size(1)): # channel
+ r_sd, r_mean = torch.std_mean(reference["samples"][i, c], dim=None) # index by original dim order
+ i_sd, i_mean = torch.std_mean(t[i, c], dim=None)
+
+ t[i, c] = ((t[i, c] - i_mean) / i_sd) * r_sd + r_mean
+
+ latents_copy["samples"] = torch.lerp(latents["samples"], t, factor)
+ return (latents_copy,)
+
+
+class LTXI2VLongMultiPromptPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ r"""
+ Long-duration I2V (image-to-video) multi-prompt pipeline with ComfyUI parity.
+
+ Key features:
+ - Temporal sliding-window sampling only (no spatial H/W sharding); autoregressive fusion across windows.
+ - Multi-prompt segmentation per window with smooth transitions at window heads.
+ - First-frame hard conditioning via per-token mask for I2V.
+ - VRAM control via temporal windowing and VAE tiled decoding.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ Args:
+ transformer ([`LTXVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`] or [`LTXEulerAncestralRFScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLLTXVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTXVideo,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: LTXVideoTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ if not isinstance(scheduler, LTXEulerAncestralRFScheduler):
+ logger.warning(
+ "For ComfyUI parity, `LTXI2VLongMultiPromptPipeline` is typically run with "
+ "`LTXEulerAncestralRFScheduler`. Got %s.",
+ scheduler.__class__.__name__,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
+ )
+
+ self.default_height = 512
+ self.default_width = 704
+ self.default_frames = 121
+ self._current_tile_T = None
+
+ @property
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.guidance_scale
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.guidance_rescale
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.do_classifier_free_guidance
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.num_timesteps
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.current_timestep
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.attention_kwargs
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.interrupt
+ def interrupt(self):
+ return self._interrupt
+
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ generator: Optional[torch.Generator],
+ dtype: torch.dtype = torch.float32,
+ latents: Optional[torch.Tensor] = None,
+ cond_latents: Optional[torch.Tensor] = None,
+ cond_strength: float = 0.0,
+ negative_index_latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], int, int, int]:
+ """
+ Prepare base latents and optionally inject first-frame conditioning latents.
+
+ Returns:
+ latents, negative_index_latents, latent_num_frames, latent_height, latent_width
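+
+        When `latents` is None, `F_lat = (num_frames - 1) // vae_temporal_compression_ratio + 1`,
+        `H_lat = height // vae_spatial_compression_ratio`, `W_lat = width // vae_spatial_compression_ratio`;
+        e.g. `num_frames=161`, `height=512`, `width=704` give `F_lat=21`, `H_lat=16`, `W_lat=22` (ratios 8, 32).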
+ """
+ if latents is None:
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ latents = torch.zeros(
+ (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width),
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ latent_num_frames = latents.shape[2]
+ latent_height = latents.shape[3]
+ latent_width = latents.shape[4]
+ latents = latents.to(device=device, dtype=dtype)
+
+ if cond_latents is not None and cond_strength > 0:
+ if negative_index_latents is None:
+ negative_index_latents = cond_latents
+ latents[:, :, :1, :, :] = cond_latents
+
+ return latents, negative_index_latents, latent_num_frames, latent_height, latent_width
+
+ # TODO: refactor this out
+ @torch.no_grad()
+ def vae_decode_tiled(
+ self,
+ latents: torch.Tensor,
+ decode_timestep: Optional[float] = None,
+ decode_noise_scale: Optional[float] = None,
+ horizontal_tiles: int = 4,
+ vertical_tiles: int = 4,
+ overlap: int = 3,
+ last_frame_fix: bool = True,
+ generator: Optional[torch.Generator] = None,
+ output_type: str = "pt",
+ auto_denormalize: bool = True,
+ compute_dtype: torch.dtype = torch.float32,
+ enable_vae_tiling: bool = False,
+ ) -> Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]:
+ """
+ VAE-based spatial tiled decoding (ComfyUI parity) implemented in Diffusers style.
+ - Linearly feather and blend overlapping tiles to avoid seams.
+        - Optional last_frame_fix: duplicate the last latent frame before decoding, then drop the last
+          `vae_temporal_compression_ratio` decoded frames at the end.
+ - Supports timestep_conditioning and decode_noise_scale injection.
+ - By default, "normalized latents" (the denoising output) are de-normalized internally (auto_denormalize=True).
+ - Tile fusion is computed in compute_dtype (float32 by default) to reduce blur and color shifts.
+
+ Args:
+ latents: [B, C_latent, F_latent, H_latent, W_latent]
+ decode_timestep: Optional decode timestep (effective only if VAE supports timestep_conditioning)
+ decode_noise_scale:
+ Optional decode noise interpolation (effective only if VAE supports timestep_conditioning)
+ horizontal_tiles, vertical_tiles: Number of tiles horizontally/vertically (>= 1)
+ overlap: Overlap in latent space (in latent pixels, >= 0)
+ last_frame_fix: Whether to enable the "repeat last frame" fix
+ generator: Random generator (used for decode_noise_scale noise)
+ output_type: "latent" | "pt" | "np" | "pil"
+ - "latent": return latents unchanged (useful for downstream processing)
+ - "pt": return tensor in VAE output space
+ - "np"/"pil": post-processed outputs via VideoProcessor.postprocess_video
+ auto_denormalize: If True, apply LTX de-normalization to `latents` internally (recommended)
+ compute_dtype: Precision used during tile fusion (float32 default; significantly reduces seam blur)
+ enable_vae_tiling: If True, delegate tiling to VAE's built-in `tiled_decode` (sets `vae.use_tiling`).
+
+ Returns:
+ - If output_type="latent": returns input `latents` unchanged
+ - If output_type="pt": returns [B, C, F, H, W] (values roughly in [-1, 1])
+ - If output_type="np"/"pil": returns post-processed outputs via postprocess_video
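+
+        Worked sizing example (illustrative): with `h_lat=64`, `vertical_tiles=4`, `overlap=3`,
+        `base_tile_h = (64 + 3 * 3) // 4 = 18`, so the tiles span latent rows `0:18`, `15:33`, `30:48`
+        and `45:64`; adjacent tiles share 3 latent rows that are linearly feathered before blending.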
+ """
+ if output_type == "latent":
+ return latents
+ if horizontal_tiles < 1 or vertical_tiles < 1:
+ raise ValueError("horizontal_tiles and vertical_tiles must be >= 1")
+ overlap = max(int(overlap), 0)
+
+ # Device and precision
+ device = self._execution_device
+ latents = latents.to(device=device, dtype=compute_dtype)
+
+ # De-normalize to VAE space (avoid color artifacts)
+ if auto_denormalize:
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ # dtype required for VAE forward pass
+ latents = latents.to(dtype=self.vae.dtype)
+
+ # Temporal/spatial upscaling ratios (parity with ComfyUI's downscale_index_formula)
+ tsf = int(self.vae_temporal_compression_ratio)
+ sf = int(self.vae_spatial_compression_ratio)
+
+ # Optional: last_frame_fix (repeat last latent frame)
+ if last_frame_fix:
+ latents = torch.cat([latents, latents[:, :, -1:].contiguous()], dim=2)
+
+ b, c_lat, f_lat, h_lat, w_lat = latents.shape
+ f_out = 1 + (f_lat - 1) * tsf
+ h_out = h_lat * sf
+ w_out = w_lat * sf
+
+ # timestep_conditioning + decode-time noise injection (aligned with pipeline)
+ if getattr(self.vae.config, "timestep_conditioning", False):
+ dt = float(decode_timestep) if decode_timestep is not None else 0.0
+ vt = torch.tensor([dt], device=device, dtype=latents.dtype)
+ if decode_noise_scale is not None:
+ dns = torch.tensor([float(decode_noise_scale)], device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ latents = (1 - dns) * latents + dns * noise
+ else:
+ vt = None
+
+ if enable_vae_tiling and hasattr(self.vae, "enable_tiling"):
+ self.vae.enable_tiling()
+ decoded = self.vae.decode(latents, vt, return_dict=False)[0]
+ if last_frame_fix:
+ decoded = decoded[:, :, :-tsf, :, :]
+ if output_type in ("np", "pil"):
+ return self.video_processor.postprocess_video(decoded, output_type=output_type)
+ return decoded
+
+ # Compute base tile sizes (in latent space)
+ base_tile_h = (h_lat + (vertical_tiles - 1) * overlap) // vertical_tiles
+ base_tile_w = (w_lat + (horizontal_tiles - 1) * overlap) // horizontal_tiles
+
+ output: Optional[torch.Tensor] = None # [B, C_img, F, H, W], fused using compute_dtype
+ weights: Optional[torch.Tensor] = None # [B, 1, F, H, W], fused using compute_dtype
+
+ # Iterate tiles in latent space (no temporal tiling)
+ for v in range(vertical_tiles):
+ for h in range(horizontal_tiles):
+ h_start = h * (base_tile_w - overlap)
+ v_start = v * (base_tile_h - overlap)
+
+ h_end = min(h_start + base_tile_w, w_lat) if h < horizontal_tiles - 1 else w_lat
+ v_end = min(v_start + base_tile_h, h_lat) if v < vertical_tiles - 1 else h_lat
+
+ # Slice latent tile and decode
+ tile_latents = latents[:, :, :, v_start:v_end, h_start:h_end]
+ decoded_tile = self.vae.decode(tile_latents, vt, return_dict=False)[0] # [B, C, F, Ht, Wt]
+ # Cast to high precision to reduce blending blur
+ decoded_tile = decoded_tile.to(dtype=compute_dtype)
+
+ # Initialize output buffers (compute_dtype)
+ if output is None:
+ output = torch.zeros(
+ (b, decoded_tile.shape[1], f_out, h_out, w_out),
+ device=decoded_tile.device,
+ dtype=compute_dtype,
+ )
+ weights = torch.zeros(
+ (b, 1, f_out, h_out, w_out),
+ device=decoded_tile.device,
+ dtype=compute_dtype,
+ )
+
+ # Tile placement in output pixel space
+ out_h_start = v_start * sf
+ out_h_end = v_end * sf
+ out_w_start = h_start * sf
+ out_w_end = h_end * sf
+
+ tile_out_h = out_h_end - out_h_start
+ tile_out_w = out_w_end - out_w_start
+
+ # Linear feathering weights [B, 1, F, Ht, Wt] (compute_dtype)
+ tile_weights = torch.ones(
+ (b, 1, decoded_tile.shape[2], tile_out_h, tile_out_w),
+ device=decoded_tile.device,
+ dtype=compute_dtype,
+ )
+
+ overlap_out_h = overlap * sf
+ overlap_out_w = overlap * sf
+
+ # Horizontal feathering: left/right overlaps
+ if overlap_out_w > 0:
+ if h > 0:
+ h_blend = torch.linspace(
+ 0, 1, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype
+ )
+ tile_weights[:, :, :, :, :overlap_out_w] *= h_blend.view(1, 1, 1, 1, -1)
+ if h < horizontal_tiles - 1:
+ h_blend = torch.linspace(
+ 1, 0, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype
+ )
+ tile_weights[:, :, :, :, -overlap_out_w:] *= h_blend.view(1, 1, 1, 1, -1)
+
+ # Vertical feathering: top/bottom overlaps
+ if overlap_out_h > 0:
+ if v > 0:
+ v_blend = torch.linspace(
+ 0, 1, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype
+ )
+ tile_weights[:, :, :, :overlap_out_h, :] *= v_blend.view(1, 1, 1, -1, 1)
+ if v < vertical_tiles - 1:
+ v_blend = torch.linspace(
+ 1, 0, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype
+ )
+ tile_weights[:, :, :, -overlap_out_h:, :] *= v_blend.view(1, 1, 1, -1, 1)
+
+ # Accumulate blended tile
+ output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += decoded_tile * tile_weights
+ weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += tile_weights
+
+ # Normalize, then clamp to [-1, 1] in compute_dtype to avoid color artifacts
+ output = output / (weights + 1e-8)
+ output = output.clamp(-1.0, 1.0)
+ output = output.to(dtype=self.vae.dtype)
+
+ # Optional: drop the last tsf frames after last_frame_fix
+ if last_frame_fix:
+ output = output[:, :, :-tsf, :, :]
+
+ if output_type in ("np", "pil"):
+ return self.video_processor.postprocess_video(output, output_type=output_type)
+ return output
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_segments: Optional[List[Dict[str, Any]]] = None,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ frame_rate: float = 25,
+ guidance_scale: float = 1.0,
+ guidance_rescale: float = 0.0,
+ num_inference_steps: Optional[int] = 8,
+ sigmas: Optional[Union[List[float], torch.Tensor]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ seed: Optional[int] = 0,
+ cond_image: Optional[Union["PIL.Image.Image", torch.Tensor]] = None,
+ cond_strength: float = 0.5,
+ latents: Optional[torch.Tensor] = None,
+ temporal_tile_size: int = 80,
+ temporal_overlap: int = 24,
+ temporal_overlap_cond_strength: float = 0.5,
+ adain_factor: float = 0.25,
+ guidance_latents: Optional[torch.Tensor] = None,
+ guiding_strength: float = 1.0,
+ negative_index_latents: Optional[torch.Tensor] = None,
+ negative_index_strength: float = 1.0,
+ skip_steps_sigma_threshold: Optional[float] = 1,
+ decode_timestep: Optional[float] = 0.05,
+ decode_noise_scale: Optional[float] = 0.025,
+ decode_horizontal_tiles: int = 4,
+ decode_vertical_tiles: int = 4,
+ decode_overlap: int = 3,
+ output_type: Optional[str] = "latent", # "latent" | "pt" | "np" | "pil"
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ ):
+ r"""
+ Generate an image-to-video sequence via temporal sliding windows and multi-prompt scheduling.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ Positive text prompt(s) per window. If a single string contains '|', parts are split by bars.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ Negative prompt(s) to suppress undesired content.
+ prompt_segments (`List[dict]`, *optional*):
+ Segment mapping with {"start_window", "end_window", "text"} to override prompts per window.
+ height (`int`, defaults to `512`):
+ Output image height in pixels; must be divisible by 32.
+ width (`int`, defaults to `704`):
+ Output image width in pixels; must be divisible by 32.
+ num_frames (`int`, defaults to `161`):
+ Number of output frames (in decoded pixel space).
+ frame_rate (`float`, defaults to `25`):
+ Frames-per-second; used to normalize temporal coordinates in `video_coords`.
+ guidance_scale (`float`, defaults to `1.0`):
+ CFG scale; values > 1 enable classifier-free guidance.
+ guidance_rescale (`float`, defaults to `0.0`):
+ Optional rescale to mitigate overexposure under CFG (see `rescale_noise_cfg`).
+ num_inference_steps (`int`, *optional*, defaults to `8`):
+ Denoising steps per window. Ignored if `sigmas` is provided.
+ sigmas (`List[float]` or `torch.Tensor`, *optional*):
+ Explicit sigma schedule per window; if set, overrides `num_inference_steps`.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ Controls stochasticity; list accepted but first element is used (batch=1).
+ seed (`int`, *optional*, defaults to `0`):
+ If provided, seeds the shared generator for global latents and derives a window-local generator with
+ `seed + w_start` per temporal window.
+ cond_image (`PIL.Image.Image` or `torch.Tensor`, *optional*):
+ Conditioning image; fixes frame 0 via per-token mask when `cond_strength > 0`.
+ cond_strength (`float`, defaults to `0.5`):
+ Strength of first-frame hard conditioning (smaller cond_mask ⇒ stronger preservation).
+ latents (`torch.Tensor`, *optional*):
+ Initial latents [B, C_lat, F_lat, H_lat, W_lat]; if None, sampled with `randn_tensor`.
+ temporal_tile_size (`int`, defaults to `80`):
+ Temporal window size (in decoded frames); internally scaled by VAE temporal compression.
+ temporal_overlap (`int`, defaults to `24`):
+ Overlap between consecutive windows (in decoded frames); internally scaled by compression.
+ temporal_overlap_cond_strength (`float`, defaults to `0.5`):
+ Strength for injecting previous window tail latents at new window head.
+ adain_factor (`float`, defaults to `0.25`):
+ AdaIN normalization strength for cross-window consistency (0 disables).
+ guidance_latents (`torch.Tensor`, *optional*):
+ Reference latents injected at window head; length trimmed by overlap for subsequent windows.
+ guiding_strength (`float`, defaults to `1.0`):
+ Injection strength for `guidance_latents`.
+ negative_index_latents (`torch.Tensor`, *optional*):
+ A single-frame latent appended at window head for "negative index" semantics.
+ negative_index_strength (`float`, defaults to `1.0`):
+ Injection strength for `negative_index_latents`.
+ skip_steps_sigma_threshold (`float`, *optional*, defaults to `1`):
+ Skip steps whose sigma exceeds this threshold.
+ decode_timestep (`float`, *optional*, defaults to `0.05`):
+ Decode-time timestep (if VAE supports timestep_conditioning).
+ decode_noise_scale (`float`, *optional*, defaults to `0.025`):
+ Decode-time noise mix scale (if VAE supports timestep_conditioning).
+ decode_horizontal_tiles (`int`, defaults to `4`):
+ Number of horizontal tiles during VAE decoding.
+ decode_vertical_tiles (`int`, defaults to `4`):
+ Number of vertical tiles during VAE decoding.
+ decode_overlap (`int`, defaults to `3`):
+ Overlap (in latent pixels) between tiles during VAE decoding.
+ output_type (`str`, *optional*, defaults to `"latent"`):
+ The output format of the generated video. Choose between "latent", "pt", "np", or "pil". If "latent",
+ returns latents without decoding.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ Extra attention parameters forwarded to the transformer.
+ callback_on_step_end (`PipelineCallback` or `MultiPipelineCallbacks`, *optional*):
+ Per-step callback hook.
+ callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`):
+ Keys from locals() to pass into the callback.
+ max_sequence_length (`int`, defaults to `128`):
+ Tokenizer max length for prompt encoding.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated frames. The output format depends on
+ `output_type`:
+ - "latent"/"pt": `torch.Tensor` [B, C, F, H, W]; "latent" is in normalized latent space, "pt" is VAE
+ output space.
+ - "np": `np.ndarray` post-processed.
+ - "pil": `List[PIL.Image.Image]` list of PIL images.
+
+ Shapes:
+ Latent sizes (when auto-generated):
+ - F_lat = (num_frames - 1) // vae_temporal_compression_ratio + 1
+ - H_lat = height // vae_spatial_compression_ratio
+ - W_lat = width // vae_spatial_compression_ratio
+
+ Notes:
+ - Seeding: when `seed` is provided, each temporal window uses a local generator seeded with `seed +
+ w_start`, while the shared generator is seeded once for global latents if no generator is passed;
+ otherwise the passed-in generator is reused.
+ - CFG: unified `noise_pred = uncond + w * (text - uncond)` with optional `guidance_rescale`.
+ - Memory: denoising performs full-frame predictions (no spatial tiling); decoding can be tiled to avoid
+ OOM.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. Input validation: height/width must be divisible by 32
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ # 1. Device & generator
+ device = self._execution_device
+ # Normalize generator input: accept list but use the first (batch_size=1)
+ if isinstance(generator, list):
+ generator = generator[0]
+ if seed is not None and generator is None:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ # 2. Optional i2v first frame conditioning: encode cond_image and inject at frame 0 via prepare_latents
+ cond_latents = None
+ if cond_image is not None and cond_strength > 0:
+ img = self.video_processor.preprocess(cond_image, height=height, width=width)
+ img = img.to(device=device, dtype=self.vae.dtype)
+ enc = self.vae.encode(img.unsqueeze(2)) # [B, C, 1, h, w]
+ cond_latents = enc.latent_dist.mode() if hasattr(enc, "latent_dist") else enc.latents
+ cond_latents = cond_latents.to(torch.float32)
+ cond_latents = self._normalize_latents(
+ cond_latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+
+ # 3. Global initial latents [B,C,F,H,W], optionally seeded/conditioned
+ latents, negative_index_latents, latent_num_frames, latent_height, latent_width = self.prepare_latents(
+ batch_size=1,
+ num_channels_latents=self.transformer.config.in_channels,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ device=device,
+ generator=generator,
+ dtype=torch.float32,
+ latents=latents,
+ cond_latents=cond_latents,
+ cond_strength=cond_strength,
+ negative_index_latents=negative_index_latents,
+ )
+ if guidance_latents is not None:
+ guidance_latents = guidance_latents.to(device=device, dtype=torch.float32)
+ if latents.shape[2] != guidance_latents.shape[2]:
+ raise ValueError("The number of frames in `latents` and `guidance_latents` must be the same")
+
+ # 4. Sliding windows in latent frames
+ tile_size_lat = max(1, temporal_tile_size // self.vae_temporal_compression_ratio)
+ overlap_lat = max(0, temporal_overlap // self.vae_temporal_compression_ratio)
+ windows = split_into_temporal_windows(
+ latent_num_frames, tile_size_lat, overlap_lat, self.vae_temporal_compression_ratio
+ )
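+        # Example with the defaults (temporal compression ratio 8): num_frames=161 -> 21 latent frames,
+        # temporal_tile_size=80 -> tile_size_lat=10, temporal_overlap=24 -> overlap_lat=3.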
+
+ # 5. Multi-prompt segments parsing
+ segment_texts = parse_prompt_segments(prompt, prompt_segments)
+
+ out_latents = None
+ first_window_latents = None
+
+ # 6. Process each temporal window
+ for w_idx, (w_start, w_end) in enumerate(windows):
+ if self.interrupt:
+ break
+
+ # 6.1 Encode prompt embeddings per window segment
+ seg_index = min(w_idx, len(segment_texts) - 1) if segment_texts else 0
+ pos_text = segment_texts[seg_index] if segment_texts else (prompt if isinstance(prompt, str) else "")
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=[pos_text],
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=1,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=None,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 6.2 Window-level timesteps reset: fresh sampling for each temporal window
+ if sigmas is not None:
+ s = torch.tensor(sigmas, dtype=torch.float32) if not isinstance(sigmas, torch.Tensor) else sigmas
+ self.scheduler.set_timesteps(sigmas=s, device=device)
+ self._num_timesteps = len(sigmas)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)
+ self._num_timesteps = num_inference_steps
+
+ # 6.3 Extract window latents [B,C,T,H,W]
+ window_latents = latents[:, :, w_start:w_end]
+ window_guidance_latents = guidance_latents[:, :, w_start:w_end] if guidance_latents is not None else None
+ window_T = window_latents.shape[2]
+
+ # 6.4 Build per-window cond mask and inject previous tails / reference
+ window_cond_mask_5d = torch.ones(
+ (1, 1, window_T, latent_height, latent_width), device=device, dtype=torch.float32
+ )
+ self._current_tile_T = window_T
+ prev_overlap_len = 0
+ # Inter-window tail latent injection (Extend)
+ if w_idx > 0 and overlap_lat > 0 and out_latents is not None:
+ k = min(overlap_lat, out_latents.shape[2])
+ prev_tail = out_latents[:, :, -k:]
+ window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents(
+ window_latents,
+ prev_tail,
+ window_cond_mask_5d,
+ overlap_lat,
+ temporal_overlap_cond_strength,
+ prev_overlap_len,
+ )
+ # Reference/negative-index latent injection (append 1 frame at window head; controlled by negative_index_strength)
+ if window_guidance_latents is not None:
+ guiding_len = (
+ window_guidance_latents.shape[2] if w_idx == 0 else window_guidance_latents.shape[2] - overlap_lat
+ )
+ window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents(
+ window_latents,
+ window_guidance_latents[:, :, -guiding_len:],
+ window_cond_mask_5d,
+ guiding_len,
+ guiding_strength,
+ prev_overlap_len,
+ )
+ else:
+ guiding_len = 0
+ window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents(
+ window_latents,
+ negative_index_latents,
+ window_cond_mask_5d,
+ 1,
+ negative_index_strength,
+ prev_overlap_len,
+ )
+ if w_idx == 0 and cond_image is not None and cond_strength > 0:
+ # First-frame I2V: smaller mask means stronger preservation of the original latent
+ window_cond_mask_5d[:, :, 0] = 1.0 - cond_strength
+
+ # Update effective window latent sizes (consider injections on T/H/W)
+ w_B, w_C, w_T_eff, w_H_eff, w_W_eff = window_latents.shape
+ p = self.transformer_spatial_patch_size
+ pt = self.transformer_temporal_patch_size
+
+ # 6.5 Pack full-window latents/masks once
+ # Seeding policy: derive a window-local generator to decouple RNG across windows
+ if seed is not None:
+ tile_seed = int(seed) + int(w_start)
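+                # e.g. seed=42 and a window starting at latent frame 10 -> tile_seed=52 (illustrative values)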
+ local_gen = torch.Generator(device=device).manual_seed(tile_seed)
+ else:
+ local_gen = generator
+ # randn*mask + (1-mask)*latents implements hard-condition initialization
+ init_rand = randn_tensor(window_latents.shape, generator=local_gen, device=device, dtype=torch.float32)
+ mixed_latents = init_rand * window_cond_mask_5d + (1 - window_cond_mask_5d) * window_latents
+ window_latents_packed = self._pack_latents(
+ window_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ latents_packed = self._pack_latents(
+ mixed_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ cond_mask_tokens = self._pack_latents(
+ window_cond_mask_5d, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ if self.do_classifier_free_guidance:
+ cond_mask = torch.cat([cond_mask_tokens, cond_mask_tokens], dim=0)
+ else:
+ cond_mask = cond_mask_tokens
+
+ # 6.6 Denoising loop per full window (no spatial tiling)
+ sigmas_current = self.scheduler.sigmas.to(device=latents_packed.device)
+ if sigmas_current.shape[0] >= 2:
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[:-1])):
+ if self.interrupt:
+ break
+ # Skip semantics: if sigma exceeds threshold, skip this step (do not call scheduler.step)
+ sigma_val = float(sigmas_current[i].item())
+ if skip_steps_sigma_threshold is not None and float(skip_steps_sigma_threshold) > 0.0:
+ if sigma_val > float(skip_steps_sigma_threshold):
+ continue
+
+ self._current_timestep = t
+
+ # Model input (stack 2 copies under CFG)
+ latent_model_input = (
+ torch.cat([latents_packed] * 2) if self.do_classifier_free_guidance else latents_packed
+ )
+ # Broadcast timesteps, combine with per-token cond mask (I2V at window head)
+ timestep = t.expand(latent_model_input.shape[0])
+ if cond_mask is not None:
+ # Broadcast timestep to per-token mask under CFG: [B] -> [B, S, 1]
+ timestep = timestep[:, None, None] * cond_mask
+
+ # Micro-conditions: only provide video_coords (num_frames/height/width set to 1)
+ rope_interpolation_scale = (
+ self.vae_temporal_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+ # Inpainting pre-blend (ComfyUI parity: KSamplerX0Inpaint:400)
+ if cond_mask_tokens is not None:
+ latents_packed = latents_packed * cond_mask_tokens + window_latents_packed * (
+ 1.0 - cond_mask_tokens
+ )
+
+ # Negative-index/overlap lengths (for segmenting time coordinates; RoPE-compatible)
+ k_negative_count = (
+ 1 if (negative_index_latents is not None and float(negative_index_strength) > 0.0) else 0
+ )
+ k_overlap_count = overlap_lat if (w_idx > 0 and overlap_lat > 0) else 0
+ video_coords = build_video_coords_for_window(
+ latents=window_latents,
+ overlap_len=int(k_overlap_count),
+ guiding_len=int(guiding_len),
+ negative_len=int(k_negative_count),
+ rope_interpolation_scale=rope_interpolation_scale,
+ frame_rate=frame_rate,
+ )
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input.to(dtype=self.transformer.dtype),
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=1,
+ height=1,
+ width=1,
+ rope_interpolation_scale=rope_interpolation_scale,
+ video_coords=video_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # Unified CFG
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ if self.guidance_rescale > 0:
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
+                    # Step with the global timestep; hard-conditioned tokens (e.g. the first frame) are re-blended
+                    # after the update to avoid brightness shifts and flicker caused by timestep misalignment.
+ latents_packed = self.scheduler.step(
+ noise_pred, t, latents_packed, generator=local_gen, return_dict=False
+ )[0]
+ # Inpainting post-blend (ComfyUI parity: restore hard-conditioned regions after update)
+ if cond_mask_tokens is not None:
+ latents_packed = latents_packed * cond_mask_tokens + window_latents_packed * (
+ 1.0 - cond_mask_tokens
+ )
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents_packed = callback_outputs.pop("latents", latents_packed)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+ else:
+ # Not enough sigmas to perform a valid step; skip this window safely.
+ pass
+
+ # 6.7 Unpack back to [B,C,T,H,W] once
+ window_out = self._unpack_latents(
+ latents_packed,
+ w_T_eff,
+ w_H_eff,
+ w_W_eff,
+ p,
+ pt,
+ )
+ if prev_overlap_len > 0:
+ window_out = window_out[:, :, :-prev_overlap_len]
+
+ # 6.8 Overlap handling and fusion
+ if out_latents is None:
+ # First window: keep all latent frames and cache as AdaIN reference
+ out_latents = window_out
+ first_window_latents = out_latents
+ else:
+ window_out = window_out[:, :, 1:] # Drop the first frame of the new window
+ if adain_factor > 0 and first_window_latents is not None:
+ window_out = adain_normalize_latents(window_out, first_window_latents, adain_factor)
+ overlap_len = max(overlap_lat - 1, 1)
+ prev_tail_chunk = out_latents[:, :, -window_out.shape[2] :]
+ fused = linear_overlap_fuse(prev_tail_chunk, window_out, overlap_len)
+ out_latents = torch.cat([out_latents[:, :, : -window_out.shape[2]], fused], dim=2)
+
+ # 7. Decode or return latent
+ if output_type == "latent":
+ video = out_latents
+ else:
+            # Decode via tiling to avoid OOM from full-frame decoding; `vae_decode_tiled` de-normalizes the
+            # latents internally (`auto_denormalize=True` by default), so they are passed in normalized space.
+ video = self.vae_decode_tiled(
+ out_latents,
+ decode_timestep=decode_timestep,
+ decode_noise_scale=decode_noise_scale,
+ horizontal_tiles=int(decode_horizontal_tiles),
+ vertical_tiles=int(decode_vertical_tiles),
+ overlap=int(decode_overlap),
+ generator=generator,
+                output_type=output_type,  # "np"/"pil" outputs are post-processed inside `vae_decode_tiled`
+ )
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index f30f8a3dc8f6..3226b045cccb 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -798,10 +798,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
timesteps,
sigmas=sigmas,
mu=mu,
diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py
new file mode 100644
index 000000000000..115e83e827a4
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/__init__.py
@@ -0,0 +1,58 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["connectors"] = ["LTX2TextConnectors"]
+ _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
+ _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
+ _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
+ _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
+ _import_structure["vocoder"] = ["LTX2Vocoder"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .connectors import LTX2TextConnectors
+ from .latent_upsampler import LTX2LatentUpsamplerModel
+ from .pipeline_ltx2 import LTX2Pipeline
+ from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
+ from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
+ from .vocoder import LTX2Vocoder
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py
new file mode 100644
index 000000000000..22ca42d37902
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/connectors.py
@@ -0,0 +1,326 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...models.attention import FeedForward
+from ...models.modeling_utils import ModelMixin
+from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
+
+
+class LTX2RotaryPosEmbed1d(nn.Module):
+ """
+ 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
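+
+    Positions are normalized by `base_seq_len` and the per-channel frequencies follow
+    `theta ** linspace(0, 1, dim // 2) * pi / 2`; the forward pass returns the matching (cos, sin)
+    tables, either interleaved along the last dimension or split per attention head (`rope_type`).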
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ base_seq_len: int = 4096,
+ theta: float = 10000.0,
+ double_precision: bool = True,
+ rope_type: str = "interleaved",
+ num_attention_heads: int = 32,
+ ):
+ super().__init__()
+ if rope_type not in ["interleaved", "split"]:
+ raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
+
+ self.dim = dim
+ self.base_seq_len = base_seq_len
+ self.theta = theta
+ self.double_precision = double_precision
+ self.rope_type = rope_type
+ self.num_attention_heads = num_attention_heads
+
+ def forward(
+ self,
+ batch_size: int,
+ pos: int,
+ device: Union[str, torch.device],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Get 1D position ids
+ grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
+ # Get fractional indices relative to self.base_seq_len
+ grid_1d = grid_1d / self.base_seq_len
+ grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
+
+ # 2. Calculate 1D RoPE frequencies
+ num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
+ freqs_dtype = torch.float64 if self.double_precision else torch.float32
+ pow_indices = torch.pow(
+ self.theta,
+ torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
+ )
+ freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
+
+ # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
+ # (self.dim // 2,).
+ freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
+
+ # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
+ if self.rope_type == "interleaved":
+ cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
+ sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
+
+ if self.dim % num_rope_elems != 0:
+ cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
+ sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
+ cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
+ sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
+
+ elif self.rope_type == "split":
+ expected_freqs = self.dim // 2
+ current_freqs = freqs.shape[-1]
+ pad_size = expected_freqs - current_freqs
+ cos_freq = freqs.cos()
+ sin_freq = freqs.sin()
+
+ if pad_size != 0:
+ cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
+ sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
+
+                cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
+                sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
+
+ # Reshape freqs to be compatible with multi-head attention
+ b = cos_freq.shape[0]
+ t = cos_freq.shape[1]
+
+ cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
+ sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
+
+ cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
+ sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
+
+ return cos_freqs, sin_freqs
+
+
+class LTX2TransformerBlock1d(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ activation_fn: str = "gelu-approximate",
+ eps: float = 1e-6,
+ rope_type: str = "interleaved",
+ ):
+ super().__init__()
+
+ self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
+ self.attn1 = LTX2Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ kv_heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ processor=LTX2AudioVideoAttnProcessor(),
+ rope_type=rope_type,
+ )
+
+ self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
+ self.ff = FeedForward(dim, activation_fn=activation_fn)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
+ hidden_states = hidden_states + attn_hidden_states
+
+ norm_hidden_states = self.norm2(hidden_states)
+ ff_hidden_states = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + ff_hidden_states
+
+ return hidden_states
+
+
+class LTX2ConnectorTransformer1d(nn.Module):
+ """
+ A 1D sequence transformer for modalities such as text.
+
+ In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
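+
+    When `num_learnable_registers` is set, padded positions are filled with tiled learnable register
+    embeddings and the attention mask passed to the transformer blocks is reset to all-valid.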
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 128,
+ num_layers: int = 2,
+        num_learnable_registers: Optional[int] = 128,
+ rope_base_seq_len: int = 4096,
+ rope_theta: float = 10000.0,
+ rope_double_precision: bool = True,
+ eps: float = 1e-6,
+ causal_temporal_positioning: bool = False,
+ rope_type: str = "interleaved",
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.inner_dim = num_attention_heads * attention_head_dim
+ self.causal_temporal_positioning = causal_temporal_positioning
+
+ self.num_learnable_registers = num_learnable_registers
+ self.learnable_registers = None
+ if num_learnable_registers is not None:
+ init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
+ self.learnable_registers = torch.nn.Parameter(init_registers)
+
+ self.rope = LTX2RotaryPosEmbed1d(
+ self.inner_dim,
+ base_seq_len=rope_base_seq_len,
+ theta=rope_theta,
+ double_precision=rope_double_precision,
+ rope_type=rope_type,
+ num_attention_heads=num_attention_heads,
+ )
+
+ self.transformer_blocks = torch.nn.ModuleList(
+ [
+ LTX2TransformerBlock1d(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ rope_type=rope_type,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ attn_mask_binarize_threshold: float = -9000.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # hidden_states shape: [batch_size, seq_len, hidden_dim]
+ # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
+ batch_size, seq_len, _ = hidden_states.shape
+
+ # 1. Replace padding with learned registers, if using
+ if self.learnable_registers is not None:
+ if seq_len % self.num_learnable_registers != 0:
+ raise ValueError(
+ f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
+ f" of learnable registers {self.num_learnable_registers}"
+ )
+
+ num_register_repeats = seq_len // self.num_learnable_registers
+ registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
+
+ binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
+ if binary_attn_mask.ndim == 4:
+ binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
+
+ hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
+ valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
+ pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
+ padded_hidden_states = [
+ F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
+ ]
+ padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
+
+ flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
+ hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
+
+ # Overwrite attention_mask with an all-zeros mask if using registers.
+ attention_mask = torch.zeros_like(attention_mask)
+
+ # 2. Calculate 1D RoPE positional embeddings
+ rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
+
+ # 3. Run 1D transformer blocks
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
+ else:
+ hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
+
+ hidden_states = self.norm_out(hidden_states)
+
+ return hidden_states, attention_mask
+
+
+class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
+ """
+ Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
+ streams.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ caption_channels: int,
+ text_proj_in_factor: int,
+ video_connector_num_attention_heads: int,
+ video_connector_attention_head_dim: int,
+ video_connector_num_layers: int,
+        video_connector_num_learnable_registers: Optional[int],
+ audio_connector_num_attention_heads: int,
+ audio_connector_attention_head_dim: int,
+ audio_connector_num_layers: int,
+        audio_connector_num_learnable_registers: Optional[int],
+ connector_rope_base_seq_len: int,
+ rope_theta: float,
+ rope_double_precision: bool,
+ causal_temporal_positioning: bool,
+ rope_type: str = "interleaved",
+ ):
+ super().__init__()
+ self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
+ self.video_connector = LTX2ConnectorTransformer1d(
+ num_attention_heads=video_connector_num_attention_heads,
+ attention_head_dim=video_connector_attention_head_dim,
+ num_layers=video_connector_num_layers,
+ num_learnable_registers=video_connector_num_learnable_registers,
+ rope_base_seq_len=connector_rope_base_seq_len,
+ rope_theta=rope_theta,
+ rope_double_precision=rope_double_precision,
+ causal_temporal_positioning=causal_temporal_positioning,
+ rope_type=rope_type,
+ )
+ self.audio_connector = LTX2ConnectorTransformer1d(
+ num_attention_heads=audio_connector_num_attention_heads,
+ attention_head_dim=audio_connector_attention_head_dim,
+ num_layers=audio_connector_num_layers,
+ num_learnable_registers=audio_connector_num_learnable_registers,
+ rope_base_seq_len=connector_rope_base_seq_len,
+ rope_theta=rope_theta,
+ rope_double_precision=rope_double_precision,
+ causal_temporal_positioning=causal_temporal_positioning,
+ rope_type=rope_type,
+ )
+
+ def forward(
+ self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
+ ):
+ # Convert to additive attention mask, if necessary
+ if not additive_mask:
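+            # Valid positions (mask == 1) map to 0; padded positions (mask == 0) map to a large negative value.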
+ text_dtype = text_encoder_hidden_states.dtype
+ attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
+ attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
+
+ text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
+
+ video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
+
+ attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
+ attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
+ video_text_embedding = video_text_embedding * attn_mask
+ new_attn_mask = attn_mask.squeeze(-1)
+
+ audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
+
+ return video_text_embedding, audio_text_embedding, new_attn_mask
diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py
new file mode 100644
index 000000000000..0bc7a59db228
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/export_utils.py
@@ -0,0 +1,134 @@
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fractions import Fraction
+from typing import Optional
+
+import torch
+
+from ...utils import is_av_available
+
+
+_CAN_USE_AV = is_av_available()
+if _CAN_USE_AV:
+ import av
+else:
+ raise ImportError(
+ "PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
+ )
+
+
+def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
+ """
+ Prepare the audio stream for writing.
+ """
+ audio_stream = container.add_stream("aac", rate=audio_sample_rate)
+ audio_stream.codec_context.sample_rate = audio_sample_rate
+ audio_stream.codec_context.layout = "stereo"
+ audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
+ return audio_stream
+
+
+def _resample_audio(
+ container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
+) -> None:
+ cc = audio_stream.codec_context
+
+ # Use the encoder's format/layout/rate as the *target*
+ target_format = cc.format or "fltp" # AAC → usually fltp
+ target_layout = cc.layout or "stereo"
+ target_rate = cc.sample_rate or frame_in.sample_rate
+
+ audio_resampler = av.audio.resampler.AudioResampler(
+ format=target_format,
+ layout=target_layout,
+ rate=target_rate,
+ )
+
+ audio_next_pts = 0
+ for rframe in audio_resampler.resample(frame_in):
+ if rframe.pts is None:
+ rframe.pts = audio_next_pts
+ audio_next_pts += rframe.samples
+ rframe.sample_rate = frame_in.sample_rate
+ container.mux(audio_stream.encode(rframe))
+
+ # flush audio encoder
+ for packet in audio_stream.encode():
+ container.mux(packet)
+
+
+def _write_audio(
+ container: av.container.Container,
+ audio_stream: av.audio.AudioStream,
+ samples: torch.Tensor,
+ audio_sample_rate: int,
+) -> None:
+ if samples.ndim == 1:
+ samples = samples[:, None]
+
+ if samples.shape[1] != 2 and samples.shape[0] == 2:
+ samples = samples.T
+
+ if samples.shape[1] != 2:
+ raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
+
+ # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
+ if samples.dtype != torch.int16:
+ samples = torch.clip(samples, -1.0, 1.0)
+ samples = (samples * 32767.0).to(torch.int16)
+
+ frame_in = av.AudioFrame.from_ndarray(
+ samples.contiguous().reshape(1, -1).cpu().numpy(),
+ format="s16",
+ layout="stereo",
+ )
+ frame_in.sample_rate = audio_sample_rate
+
+ _resample_audio(container, audio_stream, frame_in)
+
+
+def encode_video(
+ video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
+) -> None:
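+    """
+    Encode an RGB video tensor (and, optionally, a stereo audio waveform) into a video file using PyAV.
+
+    Args:
+        video (`torch.Tensor`):
+            Video frames of shape `(num_frames, height, width, 3)` with `uint8` RGB values.
+        fps (`int`):
+            Frame rate of the output video stream.
+        audio (`torch.Tensor`, *optional*):
+            Stereo audio samples of shape `(num_samples, 2)` or `(2, num_samples)`, either as `int16` values or as
+            floats in `[-1, 1]`.
+        audio_sample_rate (`int`, *optional*):
+            Sampling rate of `audio` in Hz. Required when `audio` is provided.
+        output_path (`str`):
+            Path of the output file (e.g. an `.mp4`).
+    """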
+ video_np = video.cpu().numpy()
+
+ _, height, width, _ = video_np.shape
+
+ container = av.open(output_path, mode="w")
+ stream = container.add_stream("libx264", rate=int(fps))
+ stream.width = width
+ stream.height = height
+ stream.pix_fmt = "yuv420p"
+
+ if audio is not None:
+ if audio_sample_rate is None:
+ raise ValueError("audio_sample_rate is required when audio is provided")
+
+ audio_stream = _prepare_audio_stream(container, audio_sample_rate)
+
+ for frame_array in video_np:
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+ for packet in stream.encode(frame):
+ container.mux(packet)
+
+ # Flush encoder
+ for packet in stream.encode():
+ container.mux(packet)
+
+ if audio is not None:
+ _write_audio(container, audio_stream, audio, audio_sample_rate)
+
+ container.close()
diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py
new file mode 100644
index 000000000000..69a9b1d9193f
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py
@@ -0,0 +1,285 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+
+
+RATIONAL_RESAMPLER_SCALE_MAPPING = {
+ 0.75: (3, 4),
+ 1.5: (3, 2),
+ 2.0: (2, 1),
+ 4.0: (4, 1),
+}
+
+
+# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock
+class ResBlock(torch.nn.Module):
+ def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
+ super().__init__()
+ if mid_channels is None:
+ mid_channels = channels
+
+ Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
+
+ self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
+ self.norm1 = torch.nn.GroupNorm(32, mid_channels)
+ self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
+ self.norm2 = torch.nn.GroupNorm(32, channels)
+ self.activation = torch.nn.SiLU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.activation(hidden_states + residual)
+ return hidden_states
+
+
+# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND
+class PixelShuffleND(torch.nn.Module):
+ def __init__(self, dims, upscale_factors=(2, 2, 2)):
+ super().__init__()
+
+ self.dims = dims
+ self.upscale_factors = upscale_factors
+
+ if dims not in [1, 2, 3]:
+ raise ValueError("dims must be 1, 2, or 3")
+
+ def forward(self, x):
+ if self.dims == 3:
+ # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
+ return (
+ x.unflatten(1, (-1, *self.upscale_factors[:3]))
+ .permute(0, 1, 5, 2, 6, 3, 7, 4)
+ .flatten(6, 7)
+ .flatten(4, 5)
+ .flatten(2, 3)
+ )
+ elif self.dims == 2:
+ # spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
+ return (
+ x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
+ )
+ elif self.dims == 1:
+ # temporal: b (c p1) f h w -> b c (f p1) h w
+ return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
+
+
+class BlurDownsample(torch.nn.Module):
+ """
+    Anti-aliased spatial downsampling by an integer stride using a fixed separable binomial kernel. Applies only to
+    the H and W dimensions. Works for dims=2 or dims=3 (applied per-frame).
+ """
+
+ def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
+ super().__init__()
+
+ if dims not in (2, 3):
+ raise ValueError(f"`dims` must be either 2 or 3 but is {dims}")
+ if kernel_size < 3 or kernel_size % 2 != 1:
+ raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}")
+
+ self.dims = dims
+ self.stride = stride
+ self.kernel_size = kernel_size
+
+        # Separable binomial kernel built from row `kernel_size - 1` of Pascal's triangle
+        # (e.g. [1, 4, 6, 4, 1] for the default kernel_size=5). This kernel is used for
+        # anti-aliasing and is a smooth approximation of a Gaussian filter (often called a
+        # "binomial filter"). The 2D kernel is the normalized outer product of the 1D kernel.
+ k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
+ k2d = k[:, None] @ k[None, :]
+ k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
+ self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
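+        # For the default kernel_size=5, k = [1, 4, 6, 4, 1]; the outer product sums to 16 * 16 = 256,
+        # so each normalized weight is k[i] * k[j] / 256.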
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.stride == 1:
+ return x
+
+ if self.dims == 2:
+ c = x.shape[1]
+ weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
+ x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
+ else:
+ # dims == 3: apply per-frame on H,W
+ b, c, f, _, _ = x.shape
+ x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
+
+ weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
+ x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
+
+ h2, w2 = x.shape[-2:]
+ x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W]
+ return x
+
+
+class SpatialRationalResampler(torch.nn.Module):
+ """
+    Scales the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample
+ by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the
+ input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the
+ (integer) denominator.
+ """
+
+ def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
+ super().__init__()
+ self.scale = float(scale)
+ num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None)
+ if num_denom is None:
+ raise ValueError(
+ f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}"
+ )
+ self.num, self.den = num_denom
+
+ self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
+ self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
+ self.blur_down = BlurDownsample(dims=2, stride=self.den)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Expected x shape: [B * F, C, H, W]
+ # b, _, f, h, w = x.shape
+ # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
+ x = self.conv(x)
+ x = self.pixel_shuffle(x)
+ x = self.blur_down(x)
+ # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W]
+ return x
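+
+# Worked example of the rational resampling above (hypothetical shapes): with scale=1.5 the
+# numerator/denominator pair is (3, 2), so an input of shape [B * F, C, 64, 64] is expanded to
+# 9 * C channels by the 3x3 conv, pixel-shuffled to [B * F, C, 192, 192], and then blur-downsampled
+# with stride 2 to [B * F, C, 96, 96], i.e. a net 1.5x spatial upsample.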
+
+
+class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
+ """
+ Model to spatially upsample VAE latents.
+
+ Args:
+ in_channels (`int`, defaults to `128`):
+ Number of channels in the input latent
+        mid_channels (`int`, defaults to `1024`):
+ Number of channels in the middle layers
+ num_blocks_per_stage (`int`, defaults to `4`):
+ Number of ResBlocks to use in each stage (pre/post upsampling)
+ dims (`int`, defaults to `3`):
+ Number of dimensions for convolutions (2 or 3)
+ spatial_upsample (`bool`, defaults to `True`):
+ Whether to spatially upsample the latent
+ temporal_upsample (`bool`, defaults to `False`):
+ Whether to temporally upsample the latent
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 128,
+ mid_channels: int = 1024,
+ num_blocks_per_stage: int = 4,
+ dims: int = 3,
+ spatial_upsample: bool = True,
+ temporal_upsample: bool = False,
+ rational_spatial_scale: Optional[float] = 2.0,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.mid_channels = mid_channels
+ self.num_blocks_per_stage = num_blocks_per_stage
+ self.dims = dims
+ self.spatial_upsample = spatial_upsample
+ self.temporal_upsample = temporal_upsample
+
+ ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
+
+ self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
+ self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
+ self.initial_activation = torch.nn.SiLU()
+
+ self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
+
+ if spatial_upsample and temporal_upsample:
+ self.upsampler = torch.nn.Sequential(
+ torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
+ PixelShuffleND(3),
+ )
+ elif spatial_upsample:
+ if rational_spatial_scale is not None:
+ self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
+ else:
+ self.upsampler = torch.nn.Sequential(
+ torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
+ PixelShuffleND(2),
+ )
+ elif temporal_upsample:
+ self.upsampler = torch.nn.Sequential(
+ torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
+ PixelShuffleND(1),
+ )
+ else:
+ raise ValueError("Either spatial_upsample or temporal_upsample must be True")
+
+ self.post_upsample_res_blocks = torch.nn.ModuleList(
+ [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
+ )
+
+ self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ if self.dims == 2:
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.initial_conv(hidden_states)
+ hidden_states = self.initial_norm(hidden_states)
+ hidden_states = self.initial_activation(hidden_states)
+
+ for block in self.res_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.upsampler(hidden_states)
+
+ for block in self.post_upsample_res_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.final_conv(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ else:
+ hidden_states = self.initial_conv(hidden_states)
+ hidden_states = self.initial_norm(hidden_states)
+ hidden_states = self.initial_activation(hidden_states)
+
+ for block in self.res_blocks:
+ hidden_states = block(hidden_states)
+
+ if self.temporal_upsample:
+ hidden_states = self.upsampler(hidden_states)
+ hidden_states = hidden_states[:, :, 1:, :, :]
+ else:
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.upsampler(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+
+ for block in self.post_upsample_res_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.final_conv(hidden_states)
+
+ return hidden_states
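+
+
+# Illustrative usage sketch (shapes and config values are hypothetical, not the shipped LTX-2
+# configuration):
+#
+#   upsampler = LTX2LatentUpsamplerModel(
+#       in_channels=128, mid_channels=1024, dims=3, spatial_upsample=True, rational_spatial_scale=2.0
+#   )
+#   latents = torch.randn(1, 128, 16, 16, 24)  # [B, C, F, H, W]
+#   upsampled = upsampler(latents)             # -> [1, 128, 16, 32, 48] (2x spatial upsample)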
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
new file mode 100644
index 000000000000..9cf847926347
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -0,0 +1,1141 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
+from ...models.transformers import LTX2VideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .connectors import LTX2TextConnectors
+from .pipeline_output import LTX2PipelineOutput
+from .vocoder import LTX2Vocoder
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import LTX2Pipeline
+ >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+
+ >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> frame_rate = 24.0
+ >>> video, audio = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=768,
+ ... height=512,
+ ... num_frames=121,
+ ... frame_rate=frame_rate,
+ ... num_inference_steps=40,
+ ... guidance_scale=4.0,
+ ... output_type="np",
+ ... return_dict=False,
+ ... )
+ >>> video = (video * 255).round().astype("uint8")
+ >>> video = torch.from_numpy(video)
+
+ >>> encode_video(
+ ... video[0],
+ ... fps=frame_rate,
+ ... audio=audio[0].float().cpu(),
+ ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
+ ... output_path="video.mp4",
+ ... )
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
+ r"""
+    Pipeline for joint text-to-video and audio generation using LTX 2.0.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ Args:
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video and audio latents.
+        vae ([`AutoencoderKLLTX2Video`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        audio_vae ([`AutoencoderKLLTX2Audio`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode audio spectrograms to and from latent
+            representations.
+        text_encoder ([`~transformers.Gemma3ForConditionalGeneration`]):
+            Gemma 3 text encoder used to compute the prompt embeddings.
+        tokenizer ([`~transformers.GemmaTokenizer`] or [`~transformers.GemmaTokenizerFast`]):
+            Tokenizer associated with the Gemma text encoder.
+        connectors ([`LTX2TextConnectors`]):
+            Text connector stack used to adapt text encoder hidden states for the video and audio branches.
+        transformer ([`LTX2VideoTransformer3DModel`]):
+            Conditional Transformer architecture to jointly denoise the encoded video and audio latents.
+        vocoder ([`LTX2Vocoder`]):
+            Vocoder used to convert the generated audio spectrograms into audio waveforms.
+ """
+
+ model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTX2Video,
+ audio_vae: AutoencoderKLLTX2Audio,
+ text_encoder: Gemma3ForConditionalGeneration,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ connectors: LTX2TextConnectors,
+ transformer: LTX2VideoTransformer3DModel,
+ vocoder: LTX2Vocoder,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ audio_vae=audio_vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ connectors=connectors,
+ transformer=transformer,
+ vocoder=vocoder,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+        # TODO: check whether the MEL compression ratio logic here is correct
+ self.audio_vae_mel_compression_ratio = (
+ self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+ )
+ self.audio_vae_temporal_compression_ratio = (
+ self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+            self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.audio_sampling_rate = (
+ self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000
+ )
+ self.audio_hop_length = (
+ self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
+ )
+
+ @staticmethod
+ def _pack_text_embeds(
+ text_hidden_states: torch.Tensor,
+ sequence_lengths: torch.Tensor,
+ device: Union[str, torch.device],
+ padding_side: str = "left",
+ scale_factor: int = 8,
+ eps: float = 1e-6,
+ ) -> torch.Tensor:
+ """
+ Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
+ per-layer in a masked fashion (only over non-padded positions).
+
+ Args:
+ text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
+ Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
+            sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
+ The number of valid (non-padded) tokens for each batch instance.
+ device: (`str` or `torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ padding_side: (`str`, *optional*, defaults to `"left"`):
+ Whether the text tokenizer performs padding on the `"left"` or `"right"`.
+ scale_factor (`int`, *optional*, defaults to `8`):
+ Scaling factor to multiply the normalized hidden states by.
+ eps (`float`, *optional*, defaults to `1e-6`):
+ A small positive value for numerical stability when performing normalization.
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
+ Normed and flattened text encoder hidden states.
+ """
+ batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
+ original_dtype = text_hidden_states.dtype
+
+ # Create padding mask
+ token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
+ if padding_side == "right":
+ # For right padding, valid tokens are from 0 to sequence_length-1
+ mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
+ elif padding_side == "left":
+ # For left padding, valid tokens are from (T - sequence_length) to T-1
+ start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
+ mask = token_indices >= start_indices # [B, T]
+ else:
+ raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
+ mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
+
+        # Compute masked mean over non-padding positions; shape (batch_size, 1, 1, num_layers)
+ masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
+ num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
+ masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
+
+        # Compute min/max over non-padding positions; shape (batch_size, 1, 1, num_layers)
+ x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
+ x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
+
+ # Normalization
+ normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
+ normalized_hidden_states = normalized_hidden_states * scale_factor
+
+ # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
+ normalized_hidden_states = normalized_hidden_states.flatten(2)
+ mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
+ normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
+ normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
+ return normalized_hidden_states
+
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 1024,
+ scale_factor: int = 8,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`str` or `torch.device`):
+ torch device to place the resulting embeddings on
+ dtype: (`torch.dtype`):
+ torch dtype to cast the prompt embeds to
+ max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if getattr(self, "tokenizer", None) is not None:
+ # Gemma expects left padding for chat-style prompts
+ self.tokenizer.padding_side = "left"
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ prompt = [p.strip() for p in prompt]
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ text_input_ids = text_input_ids.to(device)
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ text_encoder_outputs = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ )
+ text_encoder_hidden_states = text_encoder_outputs.hidden_states
+ text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
+ sequence_lengths = prompt_attention_mask.sum(dim=-1)
+
+ prompt_embeds = self._pack_text_embeds(
+ text_encoder_hidden_states,
+ sequence_lengths,
+ device=device,
+ padding_side=self.tokenizer.padding_side,
+ scale_factor=scale_factor,
+ )
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ scale_factor: int = 8,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ scale_factor=scale_factor,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ scale_factor=scale_factor,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ @staticmethod
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
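+
+    # Shape example for `_pack_latents` (hypothetical sizes): with patch_size = patch_size_t = 1, latents of shape
+    # [1, 128, 16, 16, 24] are flattened to a token sequence of shape [1, 16 * 16 * 24, 128] = [1, 6144, 128].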
+
+ @staticmethod
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ @staticmethod
+ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
+ latents_mean = latents_mean.to(latents.device, latents.dtype)
+ latents_std = latents_std.to(latents.device, latents.dtype)
+ return (latents * latents_std) + latents_mean
+
+ @staticmethod
+ def _pack_audio_latents(
+ latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
+ ) -> torch.Tensor:
+ # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
+ if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+ latents = latents.reshape(
+ batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+ )
+ latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+ else:
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
+ latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M]
+ return latents
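+
+    # Shape example for `_pack_audio_latents` (hypothetical sizes): without explicit patch sizes, audio latents of
+    # shape [1, 8, 126, 16] ([B, C, L, M]) become [1, 126, 8 * 16] = [1, 126, 128].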
+
+ @staticmethod
+ def _unpack_audio_latents(
+ latents: torch.Tensor,
+ latent_length: int,
+ num_mel_bins: int,
+ patch_size: Optional[int] = None,
+ patch_size_t: Optional[int] = None,
+ ) -> torch.Tensor:
+ # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],
+ # where L is the latent audio length and M is the number of mel bins.
+ if patch_size is not None and patch_size_t is not None:
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
+ latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+ else:
+ # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 128,
+ height: int = 512,
+ width: int = 768,
+ num_frames: int = 121,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ height = height // self.vae_spatial_compression_ratio
+ width = width // self.vae_spatial_compression_ratio
+ num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ return latents
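+
+    # Shape example for `prepare_latents` (default settings): height=512, width=768, num_frames=121 with a
+    # 32x spatial and 8x temporal VAE compression give latents of shape
+    # [B, 128, (121 - 1) // 8 + 1, 512 // 32, 768 // 32] = [B, 128, 16, 16, 24], which are then packed into a
+    # token sequence of shape [B, 6144, 128].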
+
+ def prepare_audio_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 8,
+ num_mel_bins: int = 64,
+ num_frames: int = 121,
+ frame_rate: float = 25.0,
+ sampling_rate: int = 16000,
+ hop_length: int = 160,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ duration_s = num_frames / frame_rate
+ latents_per_second = (
+ float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
+ )
+ latent_length = round(duration_s * latents_per_second)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype), latent_length
+
+ # TODO: confirm whether this logic is correct
+ latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+ shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_audio_latents(latents)
+ return latents, latent_length
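+
+    # Length example for `prepare_audio_latents` (default settings): num_frames=121 at frame_rate=24 gives a
+    # duration of about 5.04 s; with sampling_rate=16000, hop_length=160, and a 4x temporal compression,
+    # latents_per_second = 16000 / 160 / 4 = 25, so latent_length = round(5.0417 * 25) = 126.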
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 768,
+ num_frames: int = 121,
+ frame_rate: float = 24.0,
+ num_inference_steps: int = 40,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ guidance_rescale: float = 0.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ audio_latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            height (`int`, *optional*, defaults to `512`):
+                The height in pixels of the generated video. Must be divisible by 32.
+            width (`int`, *optional*, defaults to `768`):
+                The width in pixels of the generated video. Must be divisible by 32.
+ num_frames (`int`, *optional*, defaults to `121`):
+ The number of video frames to generate
+ frame_rate (`float`, *optional*, defaults to `24.0`):
+ The frames per second (FPS) of the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 40):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to `4.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ audio_latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ decode_timestep (`float`, defaults to `0.0`):
+ The timestep at which generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video frames. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx2.LTX2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `1024`):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+            [`~pipelines.ltx2.LTX2PipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx2.LTX2PipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the generated video and the second element is the generated audio.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
+ connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
+ prompt_embeds, additive_attention_mask, additive_mask=True
+ )
+
+ # 4. Prepare latent variables
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ video_sequence_length = latent_num_frames * latent_height * latent_width
+
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
+ latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+ num_channels_latents_audio = (
+ self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
+ )
+ audio_latents, audio_num_frames = self.prepare_audio_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents=num_channels_latents_audio,
+ num_mel_bins=num_mel_bins,
+ num_frames=num_frames, # Video frames, audio frames will be calculated from this
+ frame_rate=frame_rate,
+ sampling_rate=self.audio_sampling_rate,
+ hop_length=self.audio_hop_length,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=audio_latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 1024),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.95),
+ self.scheduler.config.get("max_shift", 2.05),
+ )
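+        # With the fallback values above and a 512x768x121 video (sequence length 6144),
+        # mu = 0.95 + (2.05 - 0.95) * (6144 - 1024) / (4096 - 1024), i.e. roughly 2.78.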
+ # For now, duplicate the scheduler for use with the audio latents
+ audio_scheduler = copy.deepcopy(self.scheduler)
+ _, _ = retrieve_timesteps(
+ audio_scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare micro-conditions
+ rope_interpolation_scale = (
+ self.vae_temporal_compression_ratio / frame_rate,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+ # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
+ video_coords = self.transformer.rope.prepare_video_coords(
+ latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
+ )
+ audio_coords = self.transformer.audio_rope.prepare_audio_coords(
+ audio_latents.shape[0], audio_num_frames, audio_latents.device
+ )
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+ audio_latent_model_input = (
+ torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
+ )
+ audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred_video, noise_pred_audio = self.transformer(
+ hidden_states=latent_model_input,
+ audio_hidden_states=audio_latent_model_input,
+ encoder_hidden_states=connector_prompt_embeds,
+ audio_encoder_hidden_states=connector_audio_prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=connector_attention_mask,
+ audio_encoder_attention_mask=connector_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ fps=frame_rate,
+ audio_num_frames=audio_num_frames,
+ video_coords=video_coords,
+ audio_coords=audio_coords,
+ # rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )
+ noise_pred_video = noise_pred_video.float()
+ noise_pred_audio = noise_pred_audio.float()
+
+ if self.do_classifier_free_guidance:
+ noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
+ noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
+ noise_pred_video_text - noise_pred_video_uncond
+ )
+
+ noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
+ noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
+ noise_pred_audio_text - noise_pred_audio_uncond
+ )
+
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
+ noise_pred_video = rescale_noise_cfg(
+ noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
+ )
+ noise_pred_audio = rescale_noise_cfg(
+ noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+ # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in
+ # the step method (such as _step_index)
+ audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # update the progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+
+ audio_latents = self._denormalize_audio_latents(
+ audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+ )
+ audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+
+ if output_type == "latent":
+ video = latents
+ audio = audio_latents
+ else:
+ latents = latents.to(prompt_embeds.dtype)
+
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
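+                # The video VAE decoder is timestep-conditioned: mix fresh noise into the latents according to `decode_noise_scale` before decoding at `decode_timestep`.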
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ latents = latents.to(self.vae.dtype)
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ audio_latents = audio_latents.to(self.audio_vae.dtype)
+ generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
+ audio = self.vocoder(generated_mel_spectrograms)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video, audio)
+
+ return LTX2PipelineOutput(frames=video, audio=audio)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
new file mode 100644
index 000000000000..b1711e283191
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
@@ -0,0 +1,1238 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
+from ...models.transformers import LTX2VideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .connectors import LTX2TextConnectors
+from .pipeline_output import LTX2PipelineOutput
+from .vocoder import LTX2Vocoder
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+        >>> from diffusers import LTX2ImageToVideoPipeline
+ >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+ ... )
+ >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> frame_rate = 24.0
+        >>> video, audio = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=768,
+ ... height=512,
+ ... num_frames=121,
+ ... frame_rate=frame_rate,
+ ... num_inference_steps=40,
+ ... guidance_scale=4.0,
+ ... output_type="np",
+ ... return_dict=False,
+ ... )
+ >>> video = (video * 255).round().astype("uint8")
+ >>> video = torch.from_numpy(video)
+
+ >>> encode_video(
+ ... video[0],
+ ... fps=frame_rate,
+ ... audio=audio[0].float().cpu(),
+ ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
+ ... output_path="video.mp4",
+ ... )
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ TODO
+ """
+
+ model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTX2Video,
+ audio_vae: AutoencoderKLLTX2Audio,
+ text_encoder: Gemma3ForConditionalGeneration,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ connectors: LTX2TextConnectors,
+ transformer: LTX2VideoTransformer3DModel,
+ vocoder: LTX2Vocoder,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ audio_vae=audio_vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ connectors=connectors,
+ transformer=transformer,
+ vocoder=vocoder,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+        # TODO: check whether the MEL compression ratio logic here is correct
+ self.audio_vae_mel_compression_ratio = (
+ self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+ )
+ self.audio_vae_temporal_compression_ratio = (
+ self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+            self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.audio_sampling_rate = (
+ self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000
+ )
+ self.audio_hop_length = (
+ self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear")
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
+ def _pack_text_embeds(
+ text_hidden_states: torch.Tensor,
+ sequence_lengths: torch.Tensor,
+ device: Union[str, torch.device],
+ padding_side: str = "left",
+ scale_factor: int = 8,
+ eps: float = 1e-6,
+ ) -> torch.Tensor:
+ """
+ Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
+ per-layer in a masked fashion (only over non-padded positions).
+
+ Args:
+ text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
+ Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
+            sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
+ The number of valid (non-padded) tokens for each batch instance.
+ device: (`str` or `torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ padding_side: (`str`, *optional*, defaults to `"left"`):
+ Whether the text tokenizer performs padding on the `"left"` or `"right"`.
+ scale_factor (`int`, *optional*, defaults to `8`):
+ Scaling factor to multiply the normalized hidden states by.
+ eps (`float`, *optional*, defaults to `1e-6`):
+ A small positive value for numerical stability when performing normalization.
+
+ Returns:
+ `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
+ Normed and flattened text encoder hidden states.
+ """
+ batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
+ original_dtype = text_hidden_states.dtype
+
+ # Create padding mask
+ token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
+ if padding_side == "right":
+ # For right padding, valid tokens are from 0 to sequence_length-1
+ mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
+ elif padding_side == "left":
+ # For left padding, valid tokens are from (T - sequence_length) to T-1
+ start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
+ mask = token_indices >= start_indices # [B, T]
+ else:
+ raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
+ mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
+
+        # Compute masked mean over non-padding positions; result has shape (batch_size, 1, 1, num_layers)
+ masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
+ num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
+ masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
+
+        # Compute min/max over non-padding positions; results have shape (batch_size, 1, 1, num_layers)
+ x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
+ x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
+
+ # Normalization
+ normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
+ normalized_hidden_states = normalized_hidden_states * scale_factor
+
+ # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
+ normalized_hidden_states = normalized_hidden_states.flatten(2)
+ mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
+ normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
+ normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
+ return normalized_hidden_states
+
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 1024,
+ scale_factor: int = 8,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`str` or `torch.device`):
+ torch device to place the resulting embeddings on
+ dtype: (`torch.dtype`):
+ torch dtype to cast the prompt embeds to
+ max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if getattr(self, "tokenizer", None) is not None:
+ # Gemma expects left padding for chat-style prompts
+ self.tokenizer.padding_side = "left"
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ prompt = [p.strip() for p in prompt]
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ text_input_ids = text_input_ids.to(device)
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ text_encoder_outputs = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ )
+ text_encoder_hidden_states = text_encoder_outputs.hidden_states
+ text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
+ sequence_lengths = prompt_attention_mask.sum(dim=-1)
+
+ prompt_embeds = self._pack_text_embeds(
+ text_encoder_hidden_states,
+ sequence_lengths,
+ device=device,
+ padding_side=self.tokenizer.padding_side,
+ scale_factor=scale_factor,
+ )
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ scale_factor: int = 8,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ scale_factor=scale_factor,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ scale_factor=scale_factor,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
+ def _pack_audio_latents(
+ latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
+ ) -> torch.Tensor:
+ # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
+ if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+ # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
+ batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+ latents = latents.reshape(
+ batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+ )
+ latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+ else:
+ # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
+ latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M]
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents
+ def _unpack_audio_latents(
+ latents: torch.Tensor,
+ latent_length: int,
+ num_mel_bins: int,
+ patch_size: Optional[int] = None,
+ patch_size_t: Optional[int] = None,
+ ) -> torch.Tensor:
+ # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],
+ # where L is the latent audio length and M is the number of mel bins.
+ if patch_size is not None and patch_size_t is not None:
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
+ latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+ else:
+ # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents
+ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
+ latents_mean = latents_mean.to(latents.device, latents.dtype)
+ latents_std = latents_std.to(latents.device, latents.dtype)
+ return (latents * latents_std) + latents_mean
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 128,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ height = height // self.vae_spatial_compression_ratio
+ width = width // self.vae_spatial_compression_ratio
+ num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+ mask_shape = (batch_size, 1, num_frames, height, width)
+
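+        # If latents were passed in, only the conditioning mask needs to be built (the first latent frame carries the clean image condition) before returning early.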
+ if latents is not None:
+ conditioning_mask = latents.new_zeros(mask_shape)
+ conditioning_mask[:, :, 0] = 1.0
+ conditioning_mask = self._pack_latents(
+ conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ ).squeeze(-1)
+ if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
+ raise ValueError(
+ f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
+ )
+ return latents.to(device=device, dtype=dtype), conditioning_mask
+
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax")
+ for i in range(batch_size)
+ ]
+ else:
+ init_latents = [
+ retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image
+ ]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+ init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
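+        # Repeat the encoded image latent across all latent frames so it can be blended with noise below.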
+ init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
+
+ # First condition is image latents and those should be kept clean.
+ conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
+ conditioning_mask[:, :, 0] = 1.0
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        # Interpolate: keep the clean image latent where the mask is 1 and use pure noise elsewhere.
+ latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
+
+ conditioning_mask = self._pack_latents(
+ conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ ).squeeze(-1)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+
+ return latents, conditioning_mask
+
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents
+ def prepare_audio_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 8,
+ num_mel_bins: int = 64,
+ num_frames: int = 121,
+ frame_rate: float = 25.0,
+ sampling_rate: int = 16000,
+ hop_length: int = 160,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ duration_s = num_frames / frame_rate
+ latents_per_second = (
+ float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
+ )
+ latent_length = round(duration_s * latents_per_second)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype), latent_length
+
+ # TODO: confirm whether this logic is correct
+ latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+ shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_audio_latents(latents)
+ return latents, latent_length
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 768,
+ num_frames: int = 121,
+ frame_rate: float = 24.0,
+ num_inference_steps: int = 40,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ guidance_rescale: float = 0.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ audio_latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ height (`int`, *optional*, defaults to `512`):
+                The height in pixels of the generated video. This is set to 512 by default for the best results.
+ width (`int`, *optional*, defaults to `768`):
+                The width in pixels of the generated video. This is set to 768 by default for the best results.
+ num_frames (`int`, *optional*, defaults to `121`):
+ The number of video frames to generate
+ frame_rate (`float`, *optional*, defaults to `24.0`):
+ The frames per second (FPS) of the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 40):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to `4.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate outputs that are
+                closely linked to the text `prompt`, usually at the expense of lower quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ audio_latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated
+                from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ decode_timestep (`float`, defaults to `0.0`):
+ The timestep at which generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx2.LTX2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `1024`):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+            [`~pipelines.ltx2.LTX2PipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx2.LTX2PipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the generated video and the second element is the generated audio.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
+ connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
+ prompt_embeds, additive_attention_mask, additive_mask=True
+ )
+
+ # 4. Prepare latent variables
+ if latents is None:
+ image = self.video_processor.preprocess(image, height=height, width=width)
+ image = image.to(device=device, dtype=prompt_embeds.dtype)
+
+ num_channels_latents = self.transformer.config.in_channels
+ latents, conditioning_mask = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
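+        # Duplicate the conditioning mask to match the doubled (unconditional + conditional) batch used for classifier-free guidance.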
+ if self.do_classifier_free_guidance:
+ conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
+
+ num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
+ latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+ num_channels_latents_audio = (
+ self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
+ )
+ audio_latents, audio_num_frames = self.prepare_audio_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents=num_channels_latents_audio,
+ num_mel_bins=num_mel_bins,
+ num_frames=num_frames, # Video frames, audio frames will be calculated from this
+ frame_rate=frame_rate,
+ sampling_rate=self.audio_sampling_rate,
+ hop_length=self.audio_hop_length,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=audio_latents,
+ )
+
+ # 5. Prepare timesteps
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ video_sequence_length = latent_num_frames * latent_height * latent_width
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 1024),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.95),
+ self.scheduler.config.get("max_shift", 2.05),
+ )
+
+ # For now, duplicate the scheduler for use with the audio latents
+ audio_scheduler = copy.deepcopy(self.scheduler)
+ _, _ = retrieve_timesteps(
+ audio_scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare micro-conditions
+ rope_interpolation_scale = (
+ self.vae_temporal_compression_ratio / frame_rate,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+ # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
+ video_coords = self.transformer.rope.prepare_video_coords(
+ latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
+ )
+ audio_coords = self.transformer.audio_rope.prepare_audio_coords(
+ audio_latents.shape[0], audio_num_frames, audio_latents.device
+ )
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+ audio_latent_model_input = (
+ torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
+ )
+ audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
+
+ timestep = t.expand(latent_model_input.shape[0])
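+                # Conditioning tokens (mask == 1) get a timestep of 0 so the clean image latent is treated as already denoised.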
+ video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
+
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred_video, noise_pred_audio = self.transformer(
+ hidden_states=latent_model_input,
+ audio_hidden_states=audio_latent_model_input,
+ encoder_hidden_states=connector_prompt_embeds,
+ audio_encoder_hidden_states=connector_audio_prompt_embeds,
+ timestep=video_timestep,
+ audio_timestep=timestep,
+ encoder_attention_mask=connector_attention_mask,
+ audio_encoder_attention_mask=connector_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ fps=frame_rate,
+ audio_num_frames=audio_num_frames,
+ video_coords=video_coords,
+ audio_coords=audio_coords,
+ # rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )
+ noise_pred_video = noise_pred_video.float()
+ noise_pred_audio = noise_pred_audio.float()
+
+ if self.do_classifier_free_guidance:
+ noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
+ noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
+ noise_pred_video_text - noise_pred_video_uncond
+ )
+
+ noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
+ noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
+ noise_pred_audio_text - noise_pred_audio_uncond
+ )
+
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
+ noise_pred_video = rescale_noise_cfg(
+ noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
+ )
+ noise_pred_audio = rescale_noise_cfg(
+ noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ noise_pred_video = self._unpack_latents(
+ noise_pred_video,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+
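+                # The first latent frame holds the clean image condition; step only the remaining frames and re-attach the first frame unchanged.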
+ noise_pred_video = noise_pred_video[:, :, 1:]
+ noise_latents = latents[:, :, 1:]
+ pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0]
+
+ latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+
+ # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in
+ # the step method (such as _step_index)
+ audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # update the progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+
+ audio_latents = self._denormalize_audio_latents(
+ audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+ )
+ audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+
+ if output_type == "latent":
+ video = latents
+ audio = audio_latents
+ else:
+ latents = latents.to(prompt_embeds.dtype)
+
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ latents = latents.to(self.vae.dtype)
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ audio_latents = audio_latents.to(self.audio_vae.dtype)
+ generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
+ audio = self.vocoder(generated_mel_spectrograms)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video, audio)
+
+ return LTX2PipelineOutput(frames=video, audio=audio)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
new file mode 100644
index 000000000000..a44c40b0430f
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
@@ -0,0 +1,442 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLLTX2Video
+from ...utils import get_logger, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..ltx.pipeline_output import LTXPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
+from .latent_upsampler import LTX2LatentUpsamplerModel
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
+ >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+ >>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+ ... )
+ >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> frame_rate = 24.0
+ >>> video, audio = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=768,
+ ... height=512,
+ ... num_frames=121,
+ ... frame_rate=frame_rate,
+ ... num_inference_steps=40,
+ ... guidance_scale=4.0,
+ ... output_type="pil",
+ ... return_dict=False,
+ ... )
+
+ >>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
+ ... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16
+ ... )
+ >>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
+ >>> upsample_pipe.vae.enable_tiling()
+ >>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16)
+
+ >>> video = upsample_pipe(
+ ... video=video,
+ ... width=768,
+ ... height=512,
+ ... output_type="np",
+ ... return_dict=False,
+ ... )[0]
+ >>> video = (video * 255).round().astype("uint8")
+ >>> video = torch.from_numpy(video)
+
+ >>> encode_video(
+ ... video[0],
+ ... fps=frame_rate,
+ ... audio=audio[0].float().cpu(),
+ ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
+ ... output_path="video.mp4",
+ ... )
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LTX2LatentUpsamplePipeline(DiffusionPipeline):
+ model_cpu_offload_seq = "vae->latent_upsampler"
+
+ def __init__(
+ self,
+ vae: AutoencoderKLLTX2Video,
+ latent_upsampler: LTX2LatentUpsamplerModel,
+ ) -> None:
+ super().__init__()
+
+ self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_frames: int = 121,
+ height: int = 512,
+ width: int = 768,
+ spatial_patch_size: int = 1,
+ temporal_patch_size: int = 1,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ if latents.ndim == 3:
+ # Convert token seq [B, S, D] to latent video [B, C, F, H, W]
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ latents = self._unpack_latents(
+ latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
+ )
+ return latents.to(device=device, dtype=dtype)
+
+ video = video.to(device=device, dtype=self.vae.dtype)
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+ ]
+ else:
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+ # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
+ # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+ return init_latents
+
+ def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
+ """
+ Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
+ tensor.
+
+ Args:
+            latents (`torch.Tensor`):
+                The input latents to normalize.
+            reference_latents (`torch.Tensor`):
+                The reference latents providing the style statistics.
+            factor (`float`, *optional*, defaults to `1.0`):
+                Blending factor between the original and AdaIN-transformed latents. Should be in [-10.0, 10.0].
+
+ Returns:
+ torch.Tensor: The transformed latent tensor
+ """
+ result = latents.clone()
+
+ for i in range(latents.size(0)):
+ for c in range(latents.size(1)):
+ r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
+ i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
+
+ result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
+
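+        # factor=0.0 returns the original latents unchanged, factor=1.0 returns the fully AdaIN-normalized result.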
+ result = torch.lerp(latents, result, factor)
+ return result
+
+ def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
+ """
+ Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
+ smooth way using a sigmoid-based compression.
+
+ This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
+ when controlling dynamic behavior with a `compression` factor.
+
+ Args:
+            latents (`torch.Tensor`):
+                Input latent tensor of arbitrary shape. Expected to be roughly in the [-1, 1] or [0, 1] range.
+            compression (`float`):
+                Compression strength in the range [0, 1].
+                - 0.0: No tone-mapping (identity transform)
+                - 1.0: Full compression effect
+
+        Returns:
+            `torch.Tensor`:
+                The tone-mapped latent tensor with the same shape as the input.
+ """
+ # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
+ scale_factor = compression * 0.75
+ abs_latents = torch.abs(latents)
+
+ # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
+ # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
+ sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
+ scales = 1.0 - 0.8 * scale_factor * sigmoid_term
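+        # e.g. at compression=1.0, values with |latent| >> 1 are scaled by ~0.4x while values near 0 stay almost unchanged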
+
+ filtered = latents * scales
+ return filtered
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
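+        # e.g. [B, F * H * W, C * p_t * p * p] -> [B, C, F * p_t, H * p, W * p]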
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
+ if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_spatial_compression_ratio} but are"
+                f" {height} and {width}."
+            )
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` can be provided.")
+ if video is None and latents is None:
+ raise ValueError("One of `video` or `latents` has to be provided.")
+
+ if not (0 <= tone_map_compression_ratio <= 1):
+ raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: Optional[List[PipelineImageInput]] = None,
+ height: int = 512,
+ width: int = 768,
+ num_frames: int = 121,
+ spatial_patch_size: int = 1,
+ temporal_patch_size: int = 1,
+ latents: Optional[torch.Tensor] = None,
+ latents_normalized: bool = False,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ adain_factor: float = 0.0,
+ tone_map_compression_ratio: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+            video (`List[PipelineImageInput]`, *optional*):
+                The video to be upsampled (such as an LTX 2.0 first-stage output). If not supplied, `latents` should
+                be supplied.
+ height (`int`, *optional*, defaults to `512`):
+ The height in pixels of the input video (not the generated video, which will have a larger resolution).
+ width (`int`, *optional*, defaults to `768`):
+ The width in pixels of the input video (not the generated video, which will have a larger resolution).
+ num_frames (`int`, *optional*, defaults to `121`):
+ The number of frames in the input video.
+            spatial_patch_size (`int`, *optional*, defaults to `1`):
+                The spatial patch size of the video latents. Used to unpack `latents` when they are supplied as a
+                packed patch sequence.
+            temporal_patch_size (`int`, *optional*, defaults to `1`):
+                The temporal patch size of the video latents. Used to unpack `latents` when they are supplied as a
+                packed patch sequence.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
+ patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
+ latent_channels, latent_frames, latent_height, latent_width)`.
+            latents_normalized (`bool`, *optional*, defaults to `False`):
+                Whether the supplied `latents` are normalized with the VAE latent mean and std. If `True`, the
+                `latents` will be denormalized before being passed to the latent upsampler.
+ decode_timestep (`float`, defaults to `0.0`):
+                The timestep at which the generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+ adain_factor (`float`, *optional*, defaults to `0.0`):
+ Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
+ Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
+ tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
+ The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
+ This is useful for regularizing high-variance latents or for conditioning outputs during generation.
+ Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
+ the full compression effect.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is the upsampled video.
+ """
+
+ self.check_inputs(
+ video=video,
+ height=height,
+ width=width,
+ latents=latents,
+ tone_map_compression_ratio=tone_map_compression_ratio,
+ )
+
+ if video is not None:
+ # Batched video input is not yet tested/supported. TODO: take a look later
+ batch_size = 1
+ else:
+ batch_size = latents.shape[0]
+ device = self._execution_device
+
+ if video is not None:
+ num_frames = len(video)
+ if num_frames % self.vae_temporal_compression_ratio != 1:
+ num_frames = (
+ num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
+ )
+                logger.warning(
+                    f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
+                )
+                video = video[:num_frames]
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ video = video.to(device=device, dtype=torch.float32)
+
+ latents_supplied = latents is not None
+ latents = self.prepare_latents(
+ video=video,
+ batch_size=batch_size,
+ num_frames=num_frames,
+ height=height,
+ width=width,
+ spatial_patch_size=spatial_patch_size,
+ temporal_patch_size=temporal_patch_size,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ if latents_supplied and latents_normalized:
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ latents = latents.to(self.latent_upsampler.dtype)
+ latents_upsampled = self.latent_upsampler(latents)
+
+ if adain_factor > 0.0:
+ latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
+ else:
+ latents = latents_upsampled
+
+ if tone_map_compression_ratio > 0.0:
+ latents = self.tone_map_latents(latents, tone_map_compression_ratio)
+
+ if output_type == "latent":
+ latents = self._normalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ video = latents
+ else:
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_output.py b/src/diffusers/pipelines/ltx2/pipeline_output.py
new file mode 100644
index 000000000000..eacd571125b0
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/pipeline_output.py
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class LTX2PipelineOutput(BaseOutput):
+ r"""
+    Output class for LTX 2.0 pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ audio (`torch.Tensor`, `np.ndarray`):
+            Generated audio waveform(s). Typically a tensor of shape `(batch_size, channels, samples)` as returned by
+            the vocoder.
+ """
+
+ frames: torch.Tensor
+ audio: torch.Tensor
diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py
new file mode 100644
index 000000000000..217c68103e39
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/vocoder.py
@@ -0,0 +1,159 @@
+import math
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+
+
+class ResBlock(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ dilations: Tuple[int, ...] = (1, 3, 5),
+ leaky_relu_negative_slope: float = 0.1,
+ padding_mode: str = "same",
+ ):
+ super().__init__()
+ self.dilations = dilations
+ self.negative_slope = leaky_relu_negative_slope
+
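+        # Each dilated conv is paired with a dilation-1 conv; the output of each pair is added back to its input (residual).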
+ self.convs1 = nn.ModuleList(
+ [
+ nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
+ for dilation in dilations
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
+ for _ in range(len(dilations))
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for conv1, conv2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, negative_slope=self.negative_slope)
+ xt = conv1(xt)
+ xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
+ xt = conv2(xt)
+ x = x + xt
+ return x
+
+
+class LTX2Vocoder(ModelMixin, ConfigMixin):
+ r"""
+ LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
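+
+    Example (illustrative sketch; shapes assume the default config, where `num_channels * num_mel_bins` must equal
+    `in_channels`):
+
+    ```py
+    >>> import torch
+
+    >>> vocoder = LTX2Vocoder()
+    >>> mel = torch.randn(1, 1, 64, 128)  # (batch_size, num_channels, time, num_mel_bins)
+    >>> waveform = vocoder(mel)  # (1, 2, 64 * 240): 2-channel audio, 240x total upsampling over the time axis
+    ```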
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 128,
+ hidden_channels: int = 1024,
+ out_channels: int = 2,
+ upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
+ upsample_factors: List[int] = [6, 5, 2, 2, 2],
+ resnet_kernel_sizes: List[int] = [3, 7, 11],
+ resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ leaky_relu_negative_slope: float = 0.1,
+ output_sampling_rate: int = 24000,
+ ):
+ super().__init__()
+ self.num_upsample_layers = len(upsample_kernel_sizes)
+ self.resnets_per_upsample = len(resnet_kernel_sizes)
+ self.out_channels = out_channels
+ self.total_upsample_factor = math.prod(upsample_factors)
+ self.negative_slope = leaky_relu_negative_slope
+
+ if self.num_upsample_layers != len(upsample_factors):
+ raise ValueError(
+ f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
+ f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
+ )
+
+ if self.resnets_per_upsample != len(resnet_dilations):
+ raise ValueError(
+ f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
+                f" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively."
+ )
+
+ self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
+
+ self.upsamplers = nn.ModuleList()
+ self.resnets = nn.ModuleList()
+ input_channels = hidden_channels
+ for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
+ output_channels = input_channels // 2
+ self.upsamplers.append(
+ nn.ConvTranspose1d(
+ input_channels, # hidden_channels // (2 ** i)
+ output_channels, # hidden_channels // (2 ** (i + 1))
+ kernel_size,
+ stride=stride,
+ padding=(kernel_size - stride) // 2,
+ )
+ )
+
+ for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
+ self.resnets.append(
+ ResBlock(
+ output_channels,
+ kernel_size,
+ dilations=dilations,
+ leaky_relu_negative_slope=leaky_relu_negative_slope,
+ )
+ )
+ input_channels = output_channels
+
+ self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
+
+ def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
+ r"""
+ Forward pass of the vocoder.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
+ is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
+ `True`.
+ time_last (`bool`, *optional*, defaults to `False`):
+                Whether the last dimension of the input is the time/frame dimension (`True`) or the mel bins dimension (`False`).
+
+ Returns:
+ `torch.Tensor`:
+                Audio waveform tensor of shape `(batch_size, out_channels, audio_length)`.
+ """
+
+ # Ensure that the time/frame dimension is last
+ if not time_last:
+ hidden_states = hidden_states.transpose(2, 3)
+ # Combine channels and frequency (mel bins) dimensions
+ hidden_states = hidden_states.flatten(1, 2)
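+        # (batch_size, num_channels, num_mel_bins, time) -> (batch_size, num_channels * num_mel_bins, time)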
+
+ hidden_states = self.conv_in(hidden_states)
+
+ for i in range(self.num_upsample_layers):
+ hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
+ hidden_states = self.upsamplers[i](hidden_states)
+
+ # Run all resnets in parallel on hidden_states
+ start = i * self.resnets_per_upsample
+ end = (i + 1) * self.resnets_per_upsample
+ resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
+
+ hidden_states = torch.mean(resnet_outputs, dim=0)
+
+ # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
+ # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
+ hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
+ hidden_states = self.conv_out(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+
+ return hidden_states
diff --git a/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
index 69f69d5768a8..8065a17b7889 100644
--- a/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
+++ b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
@@ -14,7 +14,7 @@
# limitations under the License.
#
# Modifications by Decart AI Team:
-# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
+# - Based on pipeline_wan.py, but with support for receiving a condition video appended to the channel dimension.
import html
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index b59c265646cd..f4711cf9d9d8 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -799,7 +799,13 @@ def __call__(
prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
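+        # Keep the scheduler timesteps on CPU when running under XLA; indexing them on the XLA device each step can
+        # trigger extra host syncs.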
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
index 937803edbcbc..8151b29b25fd 100644
--- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
+++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -704,10 +704,14 @@ def __call__(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
mu=mu,
)
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index 5874a92c6f2f..19a36c73f9ed 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -668,10 +668,14 @@ def __call__(
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
sigmas = np.array(sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
timesteps,
sigmas,
)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 090cb46aace4..96c209813f54 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -459,8 +459,12 @@ def __call__(
# 5. Prepare timesteps
sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas=sigmas
)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
index 1abef014301a..389927aafcbc 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -1131,8 +1131,12 @@ def __call__(
assert False
# 5. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
index 381352ccc5d4..8b7df89f039c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
@@ -1329,8 +1329,12 @@ def __call__(
assert False
# 5. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
index df5b3f5c10a5..5a6b8d5e9f37 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
@@ -85,7 +85,6 @@
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
-
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
index 1403be03a620..5b82d546445b 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
@@ -905,8 +905,12 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
index 9031877b5b8d..283862989c71 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -764,8 +764,12 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
index 9e91ccbe8006..466996889417 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -856,8 +856,12 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
index ea64f8be2c50..67676fb28798 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
@@ -952,8 +952,12 @@ def __call__(
ip_adapter_image_embeds[i] = image_embeds
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index 941b675099b9..303a0a2f0b2e 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -888,7 +888,13 @@ def __call__(
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
index f40dd52fc244..2005c865c22b 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -951,7 +951,13 @@ def __call__(
image = self.image_processor.preprocess(image, height=height, width=width)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
+ )
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
index 8351112ce409..42b5db0fa762 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
@@ -986,8 +986,12 @@ def __call__(
image = self.image_processor.preprocess(image)
# 5. set timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
index 6b1b294e10f5..cf8c4972762f 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -1094,8 +1094,12 @@ def __call__(
)
# 4. set timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
index a69f06536a55..0613ec23f740 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
@@ -1095,8 +1095,12 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
index 416d9e5677b4..1081993f46e6 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
@@ -1272,8 +1272,12 @@ def __call__(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index 6be341e07b1a..f6c4982c1c6c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -1392,8 +1392,12 @@ def __call__(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 8868e942ce3d..57d4eaa8f89e 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -758,6 +758,7 @@ def load_sub_model(
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
+ disable_mmap: bool,
quantization_config: Optional[Any] = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
@@ -801,12 +802,6 @@ def load_sub_model(
# add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
loading_kwargs = {}
- if issubclass(class_obj, torch.nn.Module):
- loading_kwargs["torch_dtype"] = torch_dtype
- if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
- loading_kwargs["provider"] = provider
- loading_kwargs["sess_options"] = sess_options
- loading_kwargs["provider_options"] = provider_options
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
@@ -821,6 +816,17 @@ def load_sub_model(
and transformers_version >= version.parse("4.20.0")
)
+ # For transformers models >= 4.56.0, use 'dtype' instead of 'torch_dtype' to avoid deprecation warnings
+ if issubclass(class_obj, torch.nn.Module):
+ if is_transformers_model and transformers_version >= version.parse("4.56.0"):
+ loading_kwargs["dtype"] = torch_dtype
+ else:
+ loading_kwargs["torch_dtype"] = torch_dtype
+ if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
+ loading_kwargs["provider"] = provider
+ loading_kwargs["sess_options"] = sess_options
+ loading_kwargs["provider_options"] = provider_options
+
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
@@ -854,6 +860,9 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
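+    # Only diffusers models (`ModelMixin`) accept `disable_mmap`; other components (e.g. transformers models) do not.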
+ if is_diffusers_model:
+ loading_kwargs["disable_mmap"] = disable_mmap
+
if is_transformers_model and is_transformers_version(">=", "4.57.0"):
loading_kwargs.pop("offload_state_dict")
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 392d5fb3feb4..b96305c74131 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -60,6 +60,7 @@
deprecate,
is_accelerate_available,
is_accelerate_version,
+ is_bitsandbytes_version,
is_hpu_available,
is_torch_npu_available,
is_torch_version,
@@ -67,6 +68,7 @@
logging,
numpy_to_pil,
)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
@@ -443,7 +445,10 @@ def module_is_sequentially_offloaded(module):
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
- if is_loaded_in_8bit_bnb:
+ # https://github.com/huggingface/accelerate/pull/3907
+ if is_loaded_in_8bit_bnb and (
+ is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0")
+ ):
return False
return hasattr(module, "_hf_hook") and (
@@ -522,9 +527,10 @@ def module_is_offloaded(module):
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
)
- if is_loaded_in_8bit_bnb and device is not None:
+ if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"):
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
+                " You need to upgrade bitsandbytes to at least 0.48.0."
)
# Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
@@ -541,6 +547,14 @@ def module_is_offloaded(module):
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
module.to(device=device)
+ # added here https://github.com/huggingface/transformers/pull/43258
+ if (
+ is_loaded_in_8bit_bnb
+ and device is not None
+ and is_transformers_version(">", "4.58.0")
+ and is_bitsandbytes_version(">=", "0.48.0")
+ ):
+ module.to(device=device)
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
module.to(device, dtype)
@@ -707,6 +721,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
loading `from_flax`.
dduf_file(`str`, *optional*):
Load weights from the specified dduf file.
+            disable_mmap (`bool`, *optional*, defaults to `False`):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the random-access pattern of mmap very well.
> [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with
> `hf auth login`.
@@ -758,6 +775,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
+ disable_mmap = kwargs.pop("disable_mmap", False)
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -982,7 +1000,11 @@ def load_module(name, value):
# 7. Load each module in the pipeline
current_device_map = None
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
- for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+ logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
+ if not is_torch_dist_rank_zero():
+ logging_tqdm_kwargs["disable"] = True
+
+ for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
# 7.1 device_map shenanigans
if final_device_map is not None:
if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1041,6 +1063,7 @@ def load_module(name, value):
use_safetensors=use_safetensors,
dduf_entries=dduf_entries,
provider_options=provider_options,
+ disable_mmap=disable_mmap,
quantization_config=quantization_config,
)
logger.info(
@@ -1218,7 +1241,9 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
# This is because the model would already be placed on a CUDA device.
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
- if is_loaded_in_8bit_bnb:
+ if is_loaded_in_8bit_bnb and (
+ is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0")
+ ):
logger.info(
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
)
@@ -1908,10 +1933,14 @@ def progress_bar(self, iterable=None, total=None):
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
+ progress_bar_config = dict(self._progress_bar_config)
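+        # Unless the caller explicitly set `disable`, only show the progress bar on the rank-zero process.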
+ if "disable" not in progress_bar_config:
+ progress_bar_config["disable"] = not is_torch_dist_rank_zero()
+
if iterable is not None:
- return tqdm(iterable, **self._progress_bar_config)
+ return tqdm(iterable, **progress_bar_config)
elif total is not None:
- return tqdm(total=total, **self._progress_bar_config)
+ return tqdm(total=total, **progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 1d718a4852a4..2ecc13ef71bf 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -862,8 +862,12 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index bb169ac5c443..f53d3c5630f0 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -806,8 +806,12 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
index 33dc2039b986..bc3ce84e1019 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -672,11 +672,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
- )
-
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -695,7 +690,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -709,7 +703,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
index 5111096d93c1..ce6fc974a56e 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
@@ -909,7 +909,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
- txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)
@@ -920,7 +919,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
- txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
@@ -935,7 +933,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
index 102a813ab582..77d78a5ca7a1 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
@@ -852,7 +852,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
- txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)
@@ -863,7 +862,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
- txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
@@ -878,7 +876,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
index ed37b238c8c9..dd723460a59e 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -793,11 +793,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
- )
-
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -821,7 +816,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -836,7 +830,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
index d54d1881fa4e..cf467203a9d2 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
@@ -1008,11 +1008,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
- )
-
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1035,7 +1030,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -1050,7 +1044,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index ec203edf166c..257e2d846c7c 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -663,6 +663,13 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]
+ # QwenImageEditPlusPipeline does not currently support batch_size > 1
+ if batch_size > 1:
+ raise ValueError(
+ f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
+ "Please process prompts one at a time."
+ )
+
device = self._execution_device
# 3. Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
@@ -777,11 +784,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
- )
-
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -805,7 +807,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -820,7 +821,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
index cb4c5d8016bb..e0b41b8b8799 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
@@ -775,11 +775,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
- )
-
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -797,7 +792,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -811,7 +805,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
index 1915c27eb2bb..83f02539b1ba 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
@@ -944,11 +944,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
- )
-
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -966,7 +961,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -980,7 +974,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
index 7bb12c26baa4..53d2c169ee63 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
@@ -781,10 +781,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
- )
is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
@@ -809,7 +805,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
additional_t_cond=is_rgb,
return_dict=False,
@@ -825,7 +820,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
- txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
additional_t_cond=is_rgb,
return_dict=False,
@@ -885,7 +879,7 @@ def __call__(
latents = latents[:, :, 1:] # remove the first frame as it is the origin input
- latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 2beff802c6e0..33f9de7d20f0 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -927,8 +927,12 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
index 55ed7b84ebdf..9d5e17c2ed48 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
@@ -1010,8 +1010,12 @@ def __call__(
raise ValueError("`controlnet` must be of type `SanaControlNetModel`.")
# 5. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 6. Prepare latents.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 04f45f817efb..4c6d2247495d 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -790,10 +790,14 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
timesteps,
sigmas=None,
max_timesteps=max_timesteps,
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
index d6cd7d7feceb..1b1c8ee097c5 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
@@ -545,22 +545,24 @@ def __call__(
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
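
The `cache_context("cond")` / `cache_context("uncond")` wrappers added throughout the SkyReels-V2 pipelines let the transformer's caching helpers keep separate state for the conditional and unconditional passes of classifier-free guidance. A rough, self-contained sketch of the idea with a toy stand-in for the transformer (the real context manager comes from diffusers' cache mixin; everything below is illustrative only):

```python
from contextlib import contextmanager


class ToyTransformer:
    """Stand-in that keeps one cache bucket per named context ("cond", "uncond", ...)."""

    def __init__(self):
        self._caches = {}
        self._active = None

    @contextmanager
    def cache_context(self, name: str):
        # Select (or create) the cache bucket for this guidance branch.
        self._active = self._caches.setdefault(name, {})
        try:
            yield
        finally:
            self._active = None

    def __call__(self, x):
        # A real model would read/write self._active here to reuse activations.
        self._active["last_input"] = x
        return x * 2


model = ToyTransformer()
with model.cache_context("cond"):
    cond_out = model(1.0)
with model.cache_context("uncond"):
    uncond_out = model(3.0)
# The two branches no longer overwrite each other's cached state.
```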
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
index 089f92632d38..4bc0d0aaea83 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
@@ -887,25 +887,28 @@ def __call__(
)
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- enable_diffusion_forcing=True,
- fps=fps_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
update_mask_i = step_update_mask[i]
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
index 2951a9447386..3e2004533258 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
@@ -966,25 +966,28 @@ def __call__(
)
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- enable_diffusion_forcing=True,
- fps=fps_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
update_mask_i = step_update_mask[i]
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
index 6fedfc795a40..234ec531b862 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
@@ -974,25 +974,28 @@ def __call__(
)
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- enable_diffusion_forcing=True,
- fps=fps_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
update_mask_i = step_update_mask[i]
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
index d61b687eadc3..d1df7f5f34cb 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
@@ -678,24 +678,26 @@ def __call__(
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- encoder_hidden_states_image=image_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
encoder_hidden_states_image=image_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index cb97f18efeff..d079d2a225cf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -989,8 +989,11 @@ def __call__(
)
# 4. Prepare timesteps
+ timestep_device = device
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
@@ -1093,6 +1096,8 @@ def __call__(
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+ if XLA_AVAILABLE:
+ xm.mark_step()
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
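
The XLA branches added across the Sana, SD, SDXL, and adapter pipelines follow a single pattern: build scheduler timesteps on CPU when `torch_xla` is present, and call `xm.mark_step()` before postprocessing so the pending lazy graph is flushed. A minimal sketch of that pattern in isolation, assuming an optional `torch_xla` install (the rationale comments reflect the usual motivation for this pattern, not wording from the diff):

```python
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def timestep_device_for(device):
    # Keep scheduler timesteps on CPU under XLA; indexing individual timesteps
    # on an XLA tensor would otherwise trigger device syncs or recompilation.
    return "cpu" if XLA_AVAILABLE else device


def flush_xla_graph():
    # Materialize pending lazy ops before handing tensors to image postprocessing.
    if XLA_AVAILABLE:
        xm.mark_step()
```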
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 95d3ab06f02a..d0be0ee51317 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -1050,8 +1050,12 @@ def __call__(
image = self.image_processor.preprocess(image)
# 5. set timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 148d7386a732..82902cc7dcd0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -1136,8 +1136,12 @@ def __call__(
)
# 4. set timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index 66d5ffa6b849..a1d0407caf5e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -459,7 +459,6 @@ def __call__(
>>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
>>> import torch
-
>>> pipeline = StableDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
... )
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 660d9801df56..fcd108aef4c2 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -1025,10 +1025,14 @@ def __call__(
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- device,
+ timestep_device,
sigmas=sigmas,
**scheduler_kwargs,
)
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 9b11bc8781e7..e6ddbb5544c7 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -1047,8 +1047,12 @@ def __call__(
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas, **scheduler_kwargs
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index b947cbff0914..b1b30efc7da3 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -1167,8 +1167,13 @@ def __call__(
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
+
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+ self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas, **scheduler_kwargs
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
# check that number of inference steps is not < 1 - as this doesn't make sense
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index 295095947a12..6d93e5feab4d 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -1000,7 +1000,13 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index e969d2a21a99..3a63bb4f253a 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -1094,8 +1094,12 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 8d1da8dc102c..d1916b635f92 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -1264,8 +1264,12 @@ def __call__(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 54a1e311804c..fcfddc192b8b 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -1399,8 +1399,12 @@ def __call__(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index 6d9053faaec8..633094239dca 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -544,7 +544,13 @@ def __call__(
added_time_ids = added_time_ids.to(device)
# 6. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, None, sigmas
+ )
# 7. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 1ce6987114a7..7b6673cf16f7 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -848,8 +848,12 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 0ea3ba5046cf..bf089bf540ba 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -1130,8 +1130,12 @@ def __call__(
)
# 4. Prepare timesteps
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)
# 5. Prepare latent variables
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 78fe71ea9138..dc2bb471101d 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -76,7 +76,8 @@
def basic_clean(text):
- text = ftfy.fix_text(text)
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
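
Making `ftfy` optional simply means the mojibake fix-up is skipped when the library is missing. A standalone version of the guarded cleaner, assuming `is_ftfy_available` is importable from `diffusers.utils` (which is how the pipeline performs the check):

```python
import html

from diffusers.utils import is_ftfy_available

if is_ftfy_available():
    import ftfy


def basic_clean(text: str) -> str:
    # Only run ftfy's text repair when the optional dependency is installed.
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    # Unescape twice to handle double-encoded entities such as "&amp;amp;".
    text = html.unescape(html.unescape(text))
    return text.strip()


print(basic_clean("  Fish &amp;amp; Chips  "))  # "Fish & Chips"
```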
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
index a976126da7fe..5475b6e8b479 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
@@ -622,7 +622,13 @@ def __call__(
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if XLA_AVAILABLE:
+ timestep_device = "cpu"
+ else:
+ timestep_device = device
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, timestep_device, timesteps
+ )
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py
index 5e26862b018e..08fc4da0e7ba 100644
--- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py
+++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py
@@ -58,14 +58,13 @@
>>> # torch_dtype=torch.bfloat16,
>>> # )
- >>> # 2.0 - `config` is required
+ >>> # 2.0
>>> # controlnet = ZImageControlNetModel.from_single_file(
>>> # hf_hub_download(
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
>>> # ),
>>> # torch_dtype=torch.bfloat16,
- >>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # )
>>> pipe = ZImageControlNetPipeline.from_pretrained(
diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py
index 73ea7d0fddec..3b0f8dc288d3 100644
--- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py
@@ -50,14 +50,13 @@
... torch_dtype=torch.bfloat16,
... )
- >>> # 2.0 - `config` is required
+ >>> # 2.0
>>> # controlnet = ZImageControlNetModel.from_single_file(
>>> # hf_hub_download(
>>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors",
>>> # ),
>>> # torch_dtype=torch.bfloat16,
- >>> # config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
>>> # )
>>> pipe = ZImageControlNetInpaintPipeline.from_pretrained(
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 5dd8f56717df..cc5b5cc93d82 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8_e4m3_tensor`, `float8_e4m3_row`,
- - **Floating point X-bit quantization:**
+ - **Floating point X-bit quantization:** (supported only in torchao <= 0.14.1; not supported in torchao >= 0.15.0)
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
@@ -531,12 +531,18 @@ def post_init(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
- is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
- if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+ is_floatx_quant_type = self.quant_type.startswith("fp")
+ is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
+ if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)
+ elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
+ f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
+ )
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
@@ -622,7 +628,6 @@ def _get_torchao_quant_type_to_method(cls):
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
- fpx_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
@@ -630,6 +635,8 @@ def _get_torchao_quant_type_to_method(cls):
uintx_weight_only,
)
+ if is_torchao_version("<=", "0.14.1"):
+ from torchao.quantization import fpx_weight_only
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor
@@ -650,18 +657,21 @@ def generate_float8dq_types(dtype: torch.dtype):
return types
def generate_fpx_quantization_types(bits: int):
- types = {}
+ if is_torchao_version("<=", "0.14.1"):
+ types = {}
- for ebits in range(1, bits):
- mbits = bits - ebits - 1
- types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
+ for ebits in range(1, bits):
+ mbits = bits - ebits - 1
+ types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
- non_sign_bits = bits - 1
- default_ebits = (non_sign_bits + 1) // 2
- default_mbits = non_sign_bits - default_ebits
- types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
+ non_sign_bits = bits - 1
+ default_ebits = (non_sign_bits + 1) // 2
+ default_mbits = non_sign_bits - default_ebits
+ types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
- return types
+ return types
+ else:
+ raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
@@ -710,15 +720,15 @@ def generate_fpx_quantization_types(bits: int):
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
- # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
- # fpx weight + bfloat16/float16 activation
- **generate_fpx_quantization_types(3),
- **generate_fpx_quantization_types(4),
- **generate_fpx_quantization_types(5),
- **generate_fpx_quantization_types(6),
- **generate_fpx_quantization_types(7),
}
+ if is_torchao_version("<=", "0.14.1"):
+ FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
+ FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
+ FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
+ FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
+ FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
+
UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 29052c1ba0cb..4199e75bf331 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -66,6 +66,7 @@
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
_import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
_import_structure["scheduling_lcm"] = ["LCMScheduler"]
+ _import_structure["scheduling_ltx_euler_ancestral_rf"] = ["LTXEulerAncestralRFScheduler"]
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
@@ -168,6 +169,7 @@
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
from .scheduling_lcm import LCMScheduler
+ from .scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sasolver import SASolverScheduler
diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py
index 767fa9157f59..f4bd0cc2d74b 100644
--- a/src/diffusers/schedulers/scheduling_consistency_decoder.py
+++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py
@@ -40,6 +40,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
@@ -71,6 +78,22 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
+ """
+ A scheduler for the consistency decoder used in Stable Diffusion pipelines.
+
+ This scheduler implements a two-step denoising process using consistency models for decoding latent representations
+ into images.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, *optional*, defaults to `1024`):
+ The number of diffusion steps to train the model.
+ sigma_data (`float`, *optional*, defaults to `0.5`):
+ The standard deviation of the data distribution. Used for computing the skip and output scaling factors.
+ """
+
order = 1
@register_to_config
@@ -78,7 +101,7 @@ def __init__(
self,
num_train_timesteps: int = 1024,
sigma_data: float = 0.5,
- ):
+ ) -> None:
betas = betas_for_alpha_bar(num_train_timesteps)
alphas = 1.0 - betas
@@ -98,8 +121,18 @@ def __init__(
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
- device: Union[str, torch.device] = None,
- ):
+ device: Optional[Union[str, torch.device]] = None,
+ ) -> None:
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`, *optional*):
+ The number of diffusion steps used when generating samples with a pre-trained model. Currently, only
+ `2` inference steps are supported.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
if num_inference_steps != 2:
raise ValueError("Currently more than 2 inference steps are not supported.")
@@ -111,7 +144,15 @@ def set_timesteps(
self.c_in = self.c_in.to(device)
@property
- def init_noise_sigma(self):
+ def init_noise_sigma(self) -> torch.Tensor:
+ """
+ Return the standard deviation of the initial noise distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The initial noise sigma value from the precomputed `sqrt_one_minus_alphas_cumprod` at the first
+ timestep.
+ """
return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
@@ -146,20 +187,20 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
- timestep (`float`):
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
- A random number generator.
+ A random number generator for reproducibility.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a
- [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
+ [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.
Returns:
- [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
- If return_dict is `True`,
- [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
+ [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
+ If `return_dict` is `True`,
+ [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
a tuple is returned where the first element is the sample tensor.
"""
x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
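
The new `"laplace"` transform maps a normalized timestep to an alpha-bar value through a log-SNR that flips sign around t = 0.5. A small standalone sketch of just that math, mirroring the `alpha_bar_fn` branch added above:

```python
import math


def laplace_alpha_bar(t: float) -> float:
    # Log-SNR under the Laplace schedule; the 1e-6 keeps the log finite at t = 0 and t = 1.
    lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
    snr = math.exp(lmb)
    # Same return convention as the scheduler's alpha_bar_fn shown above.
    return math.sqrt(snr / (1 + snr))


for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f} -> alpha_bar_fn(t)={laplace_alpha_bar(t):.4f}")
```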
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 386a43db0f9c..195ff81b4c91 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -83,7 +83,7 @@ def __init__(
s_noise: float = 1.0,
rho: float = 7.0,
clip_denoised: bool = True,
- ):
+ ) -> None:
# standard deviation of the initial noise distribution
self.init_noise_sigma = sigma_max
@@ -102,21 +102,29 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index, or `None` if not yet initialized.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+ Returns:
+ `int` or `None`:
+ The begin index, or `None` if not yet set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -151,7 +159,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample
- def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+ def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray:
"""
Gets scaled timesteps from the Karras sigmas for input to the consistency model.
@@ -160,8 +168,8 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
A single Karras sigma or an array of Karras sigmas.
Returns:
- `float` or `np.ndarray`:
- A scaled input timestep or scaled input timestep array.
+ `np.ndarray`:
+ A scaled input timestep array.
"""
if not isinstance(sigmas, np.ndarray):
sigmas = np.array(sigmas, dtype=np.float64)
@@ -173,14 +181,14 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
- device: Union[str, torch.device] = None,
+ device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
- ):
+ ) -> None:
"""
Sets the timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
@@ -244,9 +252,19 @@ def set_timesteps(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Modified _convert_to_karras implementation that takes in ramp as argument
- def _convert_to_karras(self, ramp):
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+ Args:
+ ramp (`np.ndarray`):
+ A ramp array of values between 0 and 1 used to interpolate between sigma_min and sigma_max.
+
+ Returns:
+ `np.ndarray`:
+ The Karras sigma schedule array.
+ """
sigma_min: float = self.config.sigma_min
sigma_max: float = self.config.sigma_max
@@ -256,14 +274,25 @@ def _convert_to_karras(self, ramp):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
- def get_scalings(self, sigma):
+ def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Computes the scaling factors for the consistency model output.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The current sigma value in the noise schedule.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output).
+ """
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
return c_skip, c_out
- def get_scalings_for_boundary_condition(self, sigma):
+ def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Gets the scalings used in the consistency model parameterization (from Appendix C of the
[paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
@@ -275,7 +304,7 @@ def get_scalings_for_boundary_condition(self, sigma):
The current sigma in the Karras sigma schedule.
Returns:
- `tuple`:
+ `Tuple[torch.Tensor, torch.Tensor]`:
A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
(which weights the consistency model output) is the second element.
"""
@@ -348,13 +377,13 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
- timestep (`float`):
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
- return_dict (`bool`, *optional*, defaults to `True`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.
@@ -406,7 +435,10 @@ def step(
# Noise is not used for onestep sampling.
if len(self.timesteps) > 1:
noise = randn_tensor(
- model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+ model_output.shape,
+ dtype=model_output.dtype,
+ device=model_output.device,
+ generator=generator,
)
else:
noise = torch.zeros_like(model_output)
@@ -475,5 +507,12 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
+ """
+ Returns the number of training timesteps.
+
+ Returns:
+ `int`:
+ The number of training timesteps configured for the scheduler.
+ """
return self.config.num_train_timesteps
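
The expanded `_convert_to_karras` docstring describes the rho-interpolation from the EDM paper; the schedule is easy to reproduce standalone. A short numpy sketch using representative sigma bounds (the defaults below are placeholders, not necessarily this scheduler's config values):

```python
import numpy as np


def karras_sigmas(ramp: np.ndarray, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0) -> np.ndarray:
    # Karras et al. (2022): interpolate linearly in sigma**(1/rho) space, then raise back to rho.
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


ramp = np.linspace(0, 1, 5)          # ramp=0 -> sigma_max, ramp=1 -> sigma_min
print(karras_sigmas(ramp).round(3))  # monotonically decreasing noise levels
```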
diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
index 103cca81c6a5..9c6b0fcf69b6 100644
--- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -36,27 +36,30 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- sigma_min (`float`, *optional*, defaults to 0.3):
+ sigma_min (`float`, defaults to `0.3`):
Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
- sigma_max (`float`, *optional*, defaults to 500):
+ sigma_max (`float`, defaults to `500`):
Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
- sigma_data (`float`, *optional*, defaults to 1.0):
+ sigma_data (`float`, defaults to `1.0`):
The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
- sigma_schedule (`str`, *optional*, defaults to `exponential`):
- Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
- (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
- schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
- num_train_timesteps (`int`, defaults to 1000):
+ sigma_schedule (`str`, defaults to `"exponential"`):
+ Sigma schedule to compute the `sigmas`. Must be one of `"exponential"` or `"karras"`. The exponential
+ schedule was incorporated in [stabilityai/cosxl](https://huggingface.co/stabilityai/cosxl). The Karras
+ schedule is introduced in the [EDM](https://huggingface.co/papers/2206.00364) paper.
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- solver_order (`int`, defaults to 2):
+ solver_order (`int`, defaults to `2`):
The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
- prediction_type (`str`, defaults to `v_prediction`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ prediction_type (`str`, defaults to `"v_prediction"`):
+ Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
+ process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
- solver_type (`str`, defaults to `midpoint`):
- Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
- sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ rho (`float`, defaults to `7.0`):
+ The parameter for calculating the Karras sigma schedule from the EDM
+ [paper](https://huggingface.co/papers/2206.00364).
+ solver_type (`str`, defaults to `"midpoint"`):
+ Solver type for the second-order solver. Must be one of `"midpoint"` or `"heun"`. The solver type slightly
+ affects the sample quality, especially for a small number of steps. It is recommended to use `"midpoint"`.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -65,8 +68,9 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
steps, but sometimes may result in blurring.
final_sigmas_type (`str`, defaults to `"zero"`):
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ The final `sigma` value for the noise schedule during the sampling process. Must be one of `"zero"` or
+ `"sigma_min"`. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If
+ `"zero"`, the final sigma is set to 0.
"""
_compatibles = []
@@ -78,16 +82,16 @@ def __init__(
sigma_min: float = 0.3,
sigma_max: float = 500,
sigma_data: float = 1.0,
- sigma_schedule: str = "exponential",
+ sigma_schedule: Literal["exponential", "karras"] = "exponential",
num_train_timesteps: int = 1000,
solver_order: int = 2,
- prediction_type: str = "v_prediction",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "v_prediction",
rho: float = 7.0,
- solver_type: str = "midpoint",
+ solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
- final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
- ):
+ final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
+ ) -> None:
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
@@ -113,26 +117,40 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def init_noise_sigma(self):
- # standard deviation of the initial noise distribution
+ def init_noise_sigma(self) -> float:
+ """
+ The standard deviation of the initial noise distribution.
+
+ Returns:
+ `float`:
+ The initial noise sigma value computed as `sqrt(sigma_max^2 + 1)`.
+ """
return (self.config.sigma_max**2 + 1) ** 0.5
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index, or `None` if not yet initialized.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+ Returns:
+ `int` or `None`:
+ The begin index, or `None` if not yet set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -143,19 +161,63 @@ def set_begin_index(self, begin_index: int = 0):
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
- def precondition_inputs(self, sample, sigma):
+ def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Precondition the input sample by scaling it according to the EDM formulation.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample tensor to precondition.
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `torch.Tensor`:
+ The scaled input sample.
+ """
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
- def precondition_noise(self, sigma):
+ def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Precondition the noise level by computing a normalized timestep representation.
+
+ Args:
+ sigma (`float` or `torch.Tensor`):
+ The sigma (noise level) value to precondition.
+
+ Returns:
+ `torch.Tensor`:
+ The preconditioned noise value computed as `atan(sigma) / pi * 2`.
+ """
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
return sigma.atan() / math.pi * 2
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
- def precondition_outputs(self, sample, model_output, sigma):
+ def precondition_outputs(
+ self,
+ sample: torch.Tensor,
+ model_output: torch.Tensor,
+ sigma: Union[float, torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ Precondition the model outputs according to the EDM formulation.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample tensor.
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `torch.Tensor`:
+ The denoised sample computed by combining the skip connection and output scaling.
+ """
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
@@ -173,13 +235,13 @@ def precondition_outputs(self, sample, model_output, sigma):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
- current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+ Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+ need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
+ The input sample tensor.
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -195,12 +257,14 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample
- def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+ def set_timesteps(
+ self, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
@@ -242,8 +306,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
self.noise_sampler = None
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
- def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _compute_karras_sigmas(
+ self,
+ ramp: torch.Tensor,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ ) -> torch.Tensor:
+ """
+ Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ ramp (`torch.Tensor`):
+ A tensor of values in [0, 1] representing the interpolation positions.
+ sigma_min (`float`, *optional*):
+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+ sigma_max (`float`, *optional*):
+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+ Returns:
+ `torch.Tensor`:
+ The computed Karras sigma schedule.
+ """
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -254,10 +337,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
- def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
- """Implementation closely follows k-diffusion.
-
+ def _compute_exponential_sigmas(
+ self,
+ ramp: torch.Tensor,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ ) -> torch.Tensor:
+ """
+ Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+ Args:
+ ramp (`torch.Tensor`):
+ A tensor of values representing the interpolation positions.
+ sigma_min (`float`, *optional*):
+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+ sigma_max (`float`, *optional*):
+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+ Returns:
+ `torch.Tensor`:
+ The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -265,7 +365,7 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.
@@ -301,7 +401,19 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t
- def _sigma_to_alpha_sigma_t(self, sigma):
+ def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert sigma to alpha and sigma_t values for the diffusion process.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The sigma (noise level) value.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing `alpha_t` (always 1 since inputs are pre-scaled) and `sigma_t` (same as input
+ sigma).
+ """
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
sigma_t = sigma
@@ -354,7 +466,10 @@ def dpm_solver_first_order_update(
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
- sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+ sigma_t, sigma_s = (
+ self.sigmas[self.step_index + 1],
+ self.sigmas[self.step_index],
+ )
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -464,7 +579,7 @@ def index_for_timestep(
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
@@ -485,7 +600,7 @@ def step(
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
- generator=None,
+ generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -495,20 +610,19 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`int`):
+ timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -540,7 +654,10 @@ def step(
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
)
self.noise_sampler = BrownianTreeNoiseSampler(
- model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
+ model_output,
+ sigma_min=self.config.sigma_min,
+ sigma_max=self.config.sigma_max,
+ seed=seed,
)
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
model_output.device
@@ -612,9 +729,27 @@ def add_noise(
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
- def _get_conditioning_c_in(self, sigma):
+ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+ """
+ Compute the input conditioning factor for the EDM formulation.
+
+ Args:
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `float` or `torch.Tensor`:
+ The input conditioning factor `c_in`.
+ """
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
- def __len__(self):
+ def __len__(self) -> int:
+ """
+ Returns the number of training timesteps.
+
+ Returns:
+ `int`:
+ The number of training timesteps configured for the scheduler.
+ """
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index d7fe29a72ac9..74ade1d8bb86 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
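The same `laplace` branch of `alpha_bar_fn` is repeated in each scheduler below. The standalone sketch that follows (illustrative only) evaluates the transform at a few points to show the resulting alpha_bar profile: close to 1 at t = 0, sqrt(0.5) at the midpoint, and close to 0 at t = 1, i.e. a log-SNR that follows a Laplace distribution centered on t = 0.5.

import math


def laplace_alpha_bar(t: float) -> float:
    # Same transform as the `laplace` branch added in this patch; `lmb` is the log-SNR.
    lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
    snr = math.exp(lmb)
    return math.sqrt(snr / (1 + snr))


for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f} -> alpha_bar={laplace_alpha_bar(t):.4f}")
# Roughly: 0.9995, 0.7654, 0.7071, 0.6436, 0.0316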
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
index f2683d1304ec..92f7a5ab3a04 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py
index 8ae13ad49d10..e76ad9aa6ccb 100644
--- a/src/diffusers/schedulers/scheduling_ddim_inverse.py
+++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -75,6 +75,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py
index 10873a082fee..09f55ee4c24e 100644
--- a/src/diffusers/schedulers/scheduling_ddim_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index ded88b8e1e0a..d0596bb918e9 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -74,6 +74,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
@@ -207,6 +214,8 @@ def __init__(
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "laplace":
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
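With the `laplace` branch wired into `DDPMScheduler.__init__` above, selecting the new schedule is purely a config choice. A minimal usage sketch (assuming the usual `DDPMScheduler` constructor arguments):

from diffusers import DDPMScheduler

# The "laplace" value routes to betas_for_alpha_bar(..., alpha_transform_type="laplace"), as added above.
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="laplace")
print(scheduler.betas.shape)  # torch.Size([1000])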
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index 941fc16be080..ee7ab66be4c3 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -76,6 +76,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
@@ -217,6 +224,8 @@ def __init__(
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "laplace":
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index b7d64fc00bae..ebc3a33b27d3 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
index 0a9082208cf4..66fb39c0bc4d 100644
--- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index e7ba0ba1f30e..990129f5847d 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index 6696b0375f9f..a9c4fe57b68a 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
index 81c9e4134f57..5f9ce1393d83 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -143,6 +143,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 4916e1abb549..e92f880e5b85 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -62,6 +62,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index d4e8ca5e8b18..a573f032cad8 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -175,13 +175,37 @@ def set_begin_index(self, begin_index: int = 0):
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
- def precondition_inputs(self, sample, sigma):
+ def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Precondition the input sample by scaling it according to the EDM formulation.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample tensor to precondition.
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `torch.Tensor`:
+ The scaled input sample.
+ """
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
- def precondition_noise(self, sigma):
+ def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Precondition the noise level by applying a logarithmic transformation.
+
+ Args:
+ sigma (`float` or `torch.Tensor`):
+ The sigma (noise level) value to precondition.
+
+ Returns:
+ `torch.Tensor`:
+ The preconditioned noise value computed as `0.25 * log(sigma)`.
+ """
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
@@ -190,7 +214,27 @@ def precondition_noise(self, sigma):
return c_noise
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
- def precondition_outputs(self, sample, model_output, sigma):
+ def precondition_outputs(
+ self,
+ sample: torch.Tensor,
+ model_output: torch.Tensor,
+ sigma: Union[float, torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ Precondition the model outputs according to the EDM formulation.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample tensor.
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `torch.Tensor`:
+ The denoised sample computed by combining the skip connection and output scaling.
+ """
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
@@ -208,13 +252,13 @@ def precondition_outputs(self, sample, model_output, sigma):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
- current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+ Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+ need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
+ The input sample tensor.
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -274,8 +318,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
- def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _compute_karras_sigmas(
+ self,
+ ramp: torch.Tensor,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ ) -> torch.Tensor:
+ """
+ Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ ramp (`torch.Tensor`):
+ A tensor of values in [0, 1] representing the interpolation positions.
+ sigma_min (`float`, *optional*):
+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+ sigma_max (`float`, *optional*):
+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+ Returns:
+ `torch.Tensor`:
+ The computed Karras sigma schedule.
+ """
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -286,10 +349,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
- def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
- """Implementation closely follows k-diffusion.
-
+ def _compute_exponential_sigmas(
+ self,
+ ramp: torch.Tensor,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ ) -> torch.Tensor:
+ """
+ Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+ Args:
+ ramp (`torch.Tensor`):
+ A tensor of values representing the interpolation positions.
+ sigma_min (`float`, *optional*):
+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+ sigma_max (`float`, *optional*):
+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+ Returns:
+ `torch.Tensor`:
+ The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -433,7 +513,10 @@ def dpm_solver_first_order_update(
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
- sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+ sigma_t, sigma_s = (
+ self.sigmas[self.step_index + 1],
+ self.sigmas[self.step_index],
+ )
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -684,7 +767,10 @@ def step(
if self.config.algorithm_type == "sde-dpmsolver++":
noise = randn_tensor(
- model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=model_output.dtype,
)
else:
noise = None
@@ -757,7 +843,18 @@ def add_noise(
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
- def _get_conditioning_c_in(self, sigma):
+ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+ """
+ Compute the input conditioning factor for the EDM formulation.
+
+ Args:
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `float` or `torch.Tensor`:
+ The input conditioning factor `c_in`.
+ """
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index 2ed05d396514..604d8b3ea6fa 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import torch
@@ -57,29 +57,28 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- sigma_min (`float`, *optional*, defaults to 0.002):
+ sigma_min (`float`, *optional*, defaults to `0.002`):
Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
range is [0, 10].
- sigma_max (`float`, *optional*, defaults to 80.0):
+ sigma_max (`float`, *optional*, defaults to `80.0`):
Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
range is [0.2, 80.0].
- sigma_data (`float`, *optional*, defaults to 0.5):
+ sigma_data (`float`, *optional*, defaults to `0.5`):
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
- sigma_schedule (`str`, *optional*, defaults to `karras`):
- Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
- (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
- schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
- num_train_timesteps (`int`, defaults to 1000):
+ sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`):
+ Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
+ (https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model:
+ https://huggingface.co/stabilityai/cosxl.
+ num_train_timesteps (`int`, *optional*, defaults to `1000`):
The number of diffusion steps to train the model.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://huggingface.co/papers/2210.02303) paper).
- rho (`float`, *optional*, defaults to 7.0):
+ prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`):
+ Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and
+ `"v_prediction"` predicts the velocity (see section 2.4 of the [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+ rho (`float`, *optional*, defaults to `7.0`):
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
- final_sigmas_type (`str`, defaults to `"zero"`):
+ final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
"""
_compatibles = []
@@ -91,12 +90,12 @@ def __init__(
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
- sigma_schedule: str = "karras",
+ sigma_schedule: Literal["karras", "exponential"] = "karras",
num_train_timesteps: int = 1000,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
rho: float = 7.0,
- final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
- ):
+ final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
+ ) -> None:
if sigma_schedule not in ["karras", "exponential"]:
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -131,26 +130,41 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def init_noise_sigma(self):
- # standard deviation of the initial noise distribution
+ def init_noise_sigma(self) -> float:
+ """
+ Return the standard deviation of the initial noise distribution.
+
+ Returns:
+ `float`:
+ The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`.
+ """
return (self.config.sigma_max**2 + 1) ** 0.5
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
- The index counter for current timestep. It will increase 1 after each scheduler step.
+ Return the index counter for the current timestep. The index will increase by 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index, or `None` if not yet initialized.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index`
+ method.
+
+ Returns:
+ `int` or `None`:
+ The begin index, or `None` if not yet set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -160,12 +174,36 @@ def set_begin_index(self, begin_index: int = 0):
"""
self._begin_index = begin_index
- def precondition_inputs(self, sample, sigma):
+ def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Precondition the input sample by scaling it according to the EDM formulation.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample tensor to precondition.
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `torch.Tensor`:
+ The scaled input sample.
+ """
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
- def precondition_noise(self, sigma):
+ def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Precondition the noise level by applying a logarithmic transformation.
+
+ Args:
+ sigma (`float` or `torch.Tensor`):
+ The sigma (noise level) value to precondition.
+
+ Returns:
+ `torch.Tensor`:
+ The preconditioned noise value computed as `0.25 * log(sigma)`.
+ """
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
@@ -173,7 +211,27 @@ def precondition_noise(self, sigma):
return c_noise
- def precondition_outputs(self, sample, model_output, sigma):
+ def precondition_outputs(
+ self,
+ sample: torch.Tensor,
+ model_output: torch.Tensor,
+ sigma: Union[float, torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ Precondition the model outputs according to the EDM formulation.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample tensor.
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `torch.Tensor`:
+ The denoised sample computed by combining the skip connection and output scaling.
+ """
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
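The `precondition_outputs` docstrings describe "combining the skip connection and output scaling"; concretely this is the EDM formulation D(x) = c_skip * x + c_out * F(x) from Karras et al. (2022). A minimal standalone sketch of that combination follows; note that in the actual scheduler the sign of `c_out` depends on the configured `prediction_type`.

import torch


def edm_precondition_outputs(sample, model_output, sigma, sigma_data=0.5):
    # Sketch of the EDM output preconditioning: denoised = c_skip * x + c_out * F(x).
    # The scheduler flips the sign of c_out for prediction_type="v_prediction".
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip * sample + c_out * model_output


sample = torch.randn(1, 4, 8, 8)
model_output = torch.randn(1, 4, 8, 8)
print(edm_precondition_outputs(sample, model_output, sigma=1.0).shape)  # torch.Size([1, 4, 8, 8])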
@@ -190,13 +248,13 @@ def precondition_outputs(self, sample, model_output, sigma):
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
- current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+ Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+ need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
+ The input sample tensor.
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -214,19 +272,19 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
def set_timesteps(
self,
- num_inference_steps: int = None,
- device: Union[str, torch.device] = None,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
- ):
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+ sigmas (`torch.Tensor` or `List[float]`, *optional*):
Custom sigmas to use for the denoising process. If not defined, the default behavior when
`num_inference_steps` is passed will be used.
"""
@@ -262,8 +320,27 @@ def set_timesteps(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
- def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _compute_karras_sigmas(
+ self,
+ ramp: torch.Tensor,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ ) -> torch.Tensor:
+ """
+ Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ ramp (`torch.Tensor`):
+ A tensor of values in [0, 1] representing the interpolation positions.
+ sigma_min (`float`, *optional*):
+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+ sigma_max (`float`, *optional*):
+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+ Returns:
+ `torch.Tensor`:
+ The computed Karras sigma schedule.
+ """
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -273,10 +350,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
- def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
- """Implementation closely follows k-diffusion.
-
+ def _compute_exponential_sigmas(
+ self,
+ ramp: torch.Tensor,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ ) -> torch.Tensor:
+ """
+ Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+ Args:
+ ramp (`torch.Tensor`):
+ A tensor of values representing the interpolation positions.
+ sigma_min (`float`, *optional*):
+ Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+ sigma_max (`float`, *optional*):
+ Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+ Returns:
+ `torch.Tensor`:
+ The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
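The two sigma schedules documented above differ only in how they interpolate between `sigma_max` and `sigma_min`: the Karras ramp interpolates in `sigma**(1/rho)` space (the formula is visible in the hunk above), while the exponential variant interpolates in log-sigma space, following k-diffusion. A standalone sketch of both:

import math

import torch


def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Karras et al. (2022): interpolate linearly in sigma**(1/rho) space.
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


def exponential_sigmas(n, sigma_min=0.002, sigma_max=80.0):
    # Interpolate linearly in log(sigma), from sigma_max down to sigma_min.
    return torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()


print(karras_sigmas(5))       # decreasing from 80.0 to 0.002
print(exponential_sigmas(5))  # decreasing from 80.0 to 0.002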
@@ -342,32 +436,38 @@ def step(
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
pred_original_sample: Optional[torch.Tensor] = None,
- ) -> Union[EDMEulerSchedulerOutput, Tuple]:
+ ) -> Union[EDMEulerSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
- The direct output from learned diffusion model.
- timestep (`float`):
+ The direct output from the learned diffusion model.
+ timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- s_churn (`float`):
- s_tmin (`float`):
- s_tmax (`float`):
- s_noise (`float`, defaults to 1.0):
+ s_churn (`float`, *optional*, defaults to `0.0`):
+ The amount of stochasticity to add at each step. Higher values add more noise.
+ s_tmin (`float`, *optional*, defaults to `0.0`):
+ The minimum sigma threshold below which no noise is added.
+ s_tmax (`float`, *optional*, defaults to `float("inf")`):
+ The maximum sigma threshold above which no noise is added.
+ s_noise (`float`, *optional*, defaults to `1.0`):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
- A random number generator.
- return_dict (`bool`):
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
+ A random number generator for reproducibility.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.
+ pred_original_sample (`torch.Tensor`, *optional*):
+ The predicted denoised sample from a previous step. If provided, skips recomputation.
Returns:
- [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is
- returned, otherwise a tuple is returned where the first element is the sample tensor.
+ [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
+ If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the previous sample tensor and the
+ second element is the predicted original sample tensor.
"""
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
@@ -399,7 +499,10 @@ def step(
if gamma > 0:
noise = randn_tensor(
- model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+ model_output.shape,
+ dtype=model_output.dtype,
+ device=model_output.device,
+ generator=generator,
)
eps = noise * s_noise
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
@@ -478,9 +581,20 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
- def _get_conditioning_c_in(self, sigma):
+ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+ """
+ Compute the input conditioning factor for the EDM formulation.
+
+ Args:
+ sigma (`float` or `torch.Tensor`):
+ The current sigma (noise level) value.
+
+ Returns:
+ `float` or `torch.Tensor`:
+ The input conditioning factor `c_in`.
+ """
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
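With the typed signatures above, the usual EDM Euler sampling loop looks as follows. This is a minimal sketch, assuming `EDMEulerScheduler` is importable from the top-level `diffusers` namespace and substituting a dummy tensor for the real denoiser call:

import torch

from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler(sigma_schedule="karras", prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = torch.zeros_like(model_input)  # stand-in for a real denoiser forward pass
    sample = scheduler.step(model_output, t, sample).prev_sample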
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index 97fd84db5621..0258ea777747 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index a55a76626cec..4238c976e4d6 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -80,6 +80,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 9fd61d9e18d1..378a62ca8aee 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -171,8 +171,8 @@ def set_shift(self, shift: float):
def scale_noise(
self,
sample: torch.FloatTensor,
- timestep: Union[float, torch.FloatTensor],
- noise: Optional[torch.FloatTensor] = None,
+ timestep: torch.FloatTensor,
+ noise: torch.FloatTensor,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
@@ -180,8 +180,10 @@ def scale_noise(
Args:
sample (`torch.FloatTensor`):
The input sample.
- timestep (`int`, *optional*):
+ timestep (`torch.FloatTensor`):
The current timestep in the diffusion chain.
+ noise (`torch.FloatTensor`):
+ The noise tensor.
Returns:
`torch.FloatTensor`:
diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
index 6febee444c5a..6b85194f8b5e 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
@@ -110,8 +110,8 @@ def set_begin_index(self, begin_index: int = 0):
def scale_noise(
self,
sample: torch.FloatTensor,
- timestep: Union[float, torch.FloatTensor],
- noise: Optional[torch.FloatTensor] = None,
+ timestep: torch.FloatTensor,
+ noise: torch.FloatTensor,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
@@ -119,8 +119,10 @@ def scale_noise(
Args:
sample (`torch.FloatTensor`):
The input sample.
- timestep (`int`, *optional*):
+ timestep (`torch.FloatTensor`):
The current timestep in the diffusion chain.
+ noise (`torch.FloatTensor`):
+ The noise tensor.
Returns:
`torch.FloatTensor`:
@@ -130,6 +132,7 @@ def scale_noise(
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
+
sample = sigma * noise + (1.0 - sigma) * sample
return sample
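Making `noise` a required argument matches the flow-matching forward process used in the body right above, `x_t = sigma * noise + (1 - sigma) * x_0`. A tiny standalone sketch of that interpolation:

import torch


def scale_noise(sample, noise, sigma):
    # Flow-matching forward process: linear interpolation between data and noise.
    # sigma = 0.0 returns the clean sample, sigma = 1.0 returns pure noise.
    return sigma * noise + (1.0 - sigma) * sample


x0 = torch.randn(1, 4, 32, 32)
eps = torch.randn_like(x0)
x_half = scale_noise(x0, eps, sigma=0.5)
print(torch.allclose(x_half, 0.5 * (x0 + eps)))  # True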
diff --git a/src/diffusers/schedulers/scheduling_flow_match_lcm.py b/src/diffusers/schedulers/scheduling_flow_match_lcm.py
index 25186d1fe969..8ef0e2ec8175 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_lcm.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_lcm.py
@@ -192,8 +192,8 @@ def set_scale_factors(self, scale_factors: list, upscale_mode):
def scale_noise(
self,
sample: torch.FloatTensor,
- timestep: Union[float, torch.FloatTensor],
- noise: Optional[torch.FloatTensor] = None,
+ timestep: torch.FloatTensor,
+ noise: torch.FloatTensor,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
@@ -201,8 +201,10 @@ def scale_noise(
Args:
sample (`torch.FloatTensor`):
The input sample.
- timestep (`int`, *optional*):
+ timestep (`torch.FloatTensor`):
The current timestep in the diffusion chain.
+ noise (`torch.FloatTensor`):
+ The noise tensor.
Returns:
`torch.FloatTensor`:
diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py
index b113f9b49832..011f97ba5c57 100644
--- a/src/diffusers/schedulers/scheduling_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
index da40bed635e1..37849e28b23c 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
index 6dc08d4d0a86..1c2791837cca 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 0527f3533851..66dedd5a6eab 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -79,6 +79,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 276af6eeacb7..9fc9b1e64b3f 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -75,6 +75,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py b/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py
new file mode 100644
index 000000000000..6710254f4445
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py
@@ -0,0 +1,386 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+LTXEulerAncestralRFScheduler
+
+This scheduler implements a K-diffusion style Euler-Ancestral sampler specialized for flow / CONST parameterization,
+closely mirroring ComfyUI's `sample_euler_ancestral_RF` implementation used for LTX-Video.
+
+Reference implementation (ComfyUI):
+ comfy.k_diffusion.sampling.sample_euler_ancestral_RF
+"""
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, logging
+from ..utils.torch_utils import randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class LTXEulerAncestralRFSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor`):
+ Updated sample for the next step in the denoising process.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class LTXEulerAncestralRFScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler-Ancestral scheduler for LTX-Video (RF / CONST parametrization).
+
+ This scheduler is intended for models where the network is trained with a CONST-like parameterization (as in LTXV /
+ FLUX). It approximates ComfyUI's `sample_euler_ancestral_RF` sampler and is useful when reproducing ComfyUI
+ workflows inside diffusers.
+
+ The scheduler can either:
+ - reuse the [`FlowMatchEulerDiscreteScheduler`] sigma / timestep logic when only `num_inference_steps` is provided
+ (default diffusers-style usage), or
+ - follow an explicit ComfyUI-style sigma schedule when `sigmas` (or `timesteps`) are passed to [`set_timesteps`].
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ Kept for config compatibility. It does not shape the sigma schedule itself, but it scales the sigmas
+ into the `[0, num_train_timesteps]` range expected by the model's timestep conditioning.
+ eta (`float`, defaults to 1.0):
+ Stochasticity parameter. `eta=0.0` yields deterministic DDIM-like sampling; `eta=1.0` matches ComfyUI's
+ default RF behavior.
+ s_noise (`float`, defaults to 1.0):
+ Global scaling factor for the stochastic noise term.
+ """
+
+ # Allow config migration from the flow-match scheduler and back.
+ _compatibles = ["FlowMatchEulerDiscreteScheduler"]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ eta: float = 1.0,
+ s_noise: float = 1.0,
+ ):
+ # Note: num_train_timesteps is only used to scale sigmas into the model's timestep range (see set_timesteps).
+ self.num_inference_steps: Optional[int] = None
+ self.sigmas: Optional[torch.Tensor] = None
+ self.timesteps: Optional[torch.Tensor] = None
+ self._step_index: Optional[int] = None
+ self._begin_index: Optional[int] = None
+
+ @property
+ def step_index(self) -> Optional[int]:
+ return self._step_index
+
+ @property
+ def begin_index(self) -> Optional[int]:
+ """
+ The index for the first timestep. It can be set from a pipeline with `set_begin_index` to support
+ image-to-image like workflows that start denoising part-way through the schedule.
+ """
+ return self._begin_index
+
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This should be called from the pipeline before inference for
+ workflows (e.g. image-to-image) that start denoising part-way through the schedule.
+ """
+ self._begin_index = begin_index
+
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Map a (continuous) `timestep` value to an index into `self.timesteps`.
+
+ This follows the convention used in other discrete schedulers: if the same timestep value appears multiple
+ times in the schedule (which can happen when starting in the middle of the schedule), the *second* occurrence
+ is used for the first `step` call so that no sigma is accidentally skipped.
+ """
+ if schedule_timesteps is None:
+ if self.timesteps is None:
+ raise ValueError("Timesteps have not been set. Call `set_timesteps` first.")
+ schedule_timesteps = self.timesteps
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(schedule_timesteps.device)
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ if len(indices) == 0:
+ raise ValueError(
+ "Passed `timestep` is not in `self.timesteps`. Make sure to use values from `scheduler.timesteps`."
+ )
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]):
+ """
+ Initialize the internal step index based on a given timestep.
+ """
+ if self.timesteps is None:
+ raise ValueError("Timesteps have not been set. Call `set_timesteps` first.")
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ sigmas: Optional[Union[List[float], torch.Tensor]] = None,
+ timesteps: Optional[Union[List[float], torch.Tensor]] = None,
+ mu: Optional[float] = None,
+ **kwargs,
+ ):
+ """
+ Set the sigma / timestep schedule for sampling.
+
+ When `sigmas` or `timesteps` are provided explicitly, they are used as the RF sigma schedule (ComfyUI-style)
+ and are expected to include the terminal 0.0. When both are `None`, the scheduler reuses the
+ [`FlowMatchEulerDiscreteScheduler`] logic to generate sigmas from `num_inference_steps` and the stored config
+ (including any resolution-dependent shifting, Karras/beta schedules, etc.).
+
+ Args:
+ num_inference_steps (`int`, *optional*):
+ Number of denoising steps. If provided together with explicit `sigmas`/`timesteps`, they are expected
+ to be consistent and are otherwise ignored with a warning.
+ device (`str` or `torch.device`, *optional*):
+ Device to move the internal tensors to.
+ sigmas (`List[float]` or `torch.Tensor`, *optional*):
+ Explicit sigma schedule, e.g. `[1.0, 0.99, ..., 0.0]`.
+ timesteps (`List[float]` or `torch.Tensor`, *optional*):
+ Optional alias for `sigmas`. If `sigmas` is None and `timesteps` is provided, timesteps are treated as
+ sigmas.
+ mu (`float`, *optional*):
+ Optional shift parameter used when delegating to [`FlowMatchEulerDiscreteScheduler.set_timesteps`] and
+ `config.use_dynamic_shifting` is `True`.
+ """
+ # 1. Auto-generate schedule (FlowMatch-style) when no explicit sigmas/timesteps are given
+ if sigmas is None and timesteps is None:
+ if num_inference_steps is None:
+ raise ValueError(
+ "LTXEulerAncestralRFScheduler.set_timesteps requires either explicit `sigmas`/`timesteps` "
+ "or a `num_inference_steps` value."
+ )
+
+ # We reuse FlowMatchEulerDiscreteScheduler to construct a sigma schedule that is
+ # consistent with the original LTX training setup (including optional time shifting,
+ # Karras / exponential / beta schedules, etc.).
+ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+
+ base_scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.config)
+ base_scheduler.set_timesteps(
+ num_inference_steps=num_inference_steps,
+ device=device,
+ sigmas=None,
+ mu=mu,
+ timesteps=None,
+ )
+
+ self.num_inference_steps = base_scheduler.num_inference_steps
+ # Keep sigmas / timesteps on the requested device so step() can operate on-device without
+ # extra transfers.
+ self.sigmas = base_scheduler.sigmas.to(device=device)
+ self.timesteps = base_scheduler.timesteps.to(device=device)
+ self._step_index = None
+ self._begin_index = None
+ return
+
+ # 2. Explicit sigma schedule (ComfyUI-style path)
+ if sigmas is None:
+ # `timesteps` is treated as sigmas in RF / flow-matching setups.
+ sigmas = timesteps
+
+ if isinstance(sigmas, list):
+ sigmas_tensor = torch.tensor(sigmas, dtype=torch.float32)
+ elif isinstance(sigmas, torch.Tensor):
+ sigmas_tensor = sigmas.to(dtype=torch.float32)
+ else:
+ raise TypeError(f"`sigmas` must be a list or torch.Tensor, got {type(sigmas)}.")
+
+ if sigmas_tensor.ndim != 1:
+ raise ValueError(f"`sigmas` must be a 1D tensor, got shape {tuple(sigmas_tensor.shape)}.")
+
+ if sigmas_tensor[-1].abs().item() > 1e-6:
+ logger.warning(
+ "The last sigma in the schedule is not zero (%.6f). "
+ "For best compatibility with ComfyUI's RF sampler, the terminal sigma "
+ "should be 0.0.",
+ sigmas_tensor[-1].item(),
+ )
+
+ # Move to device once, then derive timesteps.
+ if device is not None:
+ sigmas_tensor = sigmas_tensor.to(device)
+
+ # Internal sigma schedule stays in [0, 1] (as provided).
+ self.sigmas = sigmas_tensor
+ # Timesteps are scaled to match the training setup of LTX (FlowMatch-style),
+ # where the network expects timesteps on [0, num_train_timesteps].
+ # This keeps the transformer conditioning in the expected range while the RF
+ # scheduler still operates on the raw sigma values.
+ num_train = float(getattr(self.config, "num_train_timesteps", 1000))
+ self.timesteps = sigmas_tensor * num_train
+
+ if num_inference_steps is not None and num_inference_steps != len(sigmas) - 1:
+ logger.warning(
+ "Provided `num_inference_steps=%d` does not match `len(sigmas)-1=%d`. "
+ "Overriding `num_inference_steps` with `len(sigmas)-1`.",
+ num_inference_steps,
+ len(sigmas) - 1,
+ )
+
+ self.num_inference_steps = len(sigmas) - 1
+ self._step_index = None
+ self._begin_index = None
+
+ def _sigma_broadcast(self, sigma: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
+ """
+ Helper to broadcast a scalar sigma to the shape of `sample`.
+ """
+ while sigma.ndim < sample.ndim:
+ sigma = sigma.view(*sigma.shape, 1)
+ return sigma
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.Tensor],
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[LTXEulerAncestralRFSchedulerOutput, Tuple[torch.FloatTensor]]:
+ """
+ Perform a single Euler-Ancestral RF update step.
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ Raw model output at the current step. Interpreted under the CONST parametrization as `v_t`, with
+ denoised state reconstructed as `x0 = x_t - sigma_t * v_t`.
+ timestep (`float` or `torch.Tensor`):
+ The current timestep value; it must match one entry in `self.timesteps`.
+ sample (`torch.FloatTensor`):
+ Current latent sample `x_t`.
+ generator (`torch.Generator`, *optional*):
+ Optional generator for reproducible noise.
+ return_dict (`bool`):
+ If `True`, return a `LTXEulerAncestralRFSchedulerOutput`; otherwise return a tuple where the first
+ element is the updated sample.
+ """
+
+ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `LTXEulerAncestralRFScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` values as `timestep`."
+ ),
+ )
+
+ if self.sigmas is None or self.timesteps is None:
+ raise ValueError("Scheduler has not been initialized. Call `set_timesteps` before `step`.")
+
+ if self._step_index is None:
+ self._init_step_index(timestep)
+
+ i = self._step_index
+ if i >= len(self.sigmas) - 1:
+ # Already at the end; simply return the current sample.
+ prev_sample = sample
+ else:
+ # Work in float32 for numerical stability
+ sample_f = sample.to(torch.float32)
+ model_output_f = model_output.to(torch.float32)
+
+ sigma = self.sigmas[i]
+ sigma_next = self.sigmas[i + 1]
+
+ sigma_b = self._sigma_broadcast(sigma.view(1), sample_f)
+ sigma_next_b = self._sigma_broadcast(sigma_next.view(1), sample_f)
+
+ # Approximate denoised x0 under CONST parametrization:
+ # x0 = x_t - sigma_t * v_t
+ denoised = sample_f - sigma_b * model_output_f
+
+ if sigma_next.abs().item() < 1e-8:
+ # Final denoising step
+ x = denoised
+ else:
+ eta = float(self.config.eta)
+ s_noise = float(self.config.s_noise)
+
+ # Downstep computation (ComfyUI RF variant)
+ downstep_ratio = 1.0 + (sigma_next / sigma - 1.0) * eta
+ sigma_down = sigma_next * downstep_ratio
+
+ alpha_ip1 = 1.0 - sigma_next
+ alpha_down = 1.0 - sigma_down
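+                # With eta == 0, sigma_down equals sigma_next and the update below is a plain deterministic
+                # Euler step; for eta > 0 the step targets the lower noise level sigma_down and the removed
+                # variance is re-injected as ancestral noise. Under the RF interpolation
+                # x_sigma = (1 - sigma) * x0 + sigma * noise, alpha = 1 - sigma is the signal coefficient.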
+
+ # Deterministic part (Euler step in (x, x0)-space)
+ sigma_down_b = self._sigma_broadcast(sigma_down.view(1), sample_f)
+ alpha_ip1_b = self._sigma_broadcast(alpha_ip1.view(1), sample_f)
+ alpha_down_b = self._sigma_broadcast(alpha_down.view(1), sample_f)
+
+ sigma_ratio = sigma_down_b / sigma_b
+ x = sigma_ratio * sample_f + (1.0 - sigma_ratio) * denoised
+
+ # Stochastic ancestral noise
+ if eta > 0.0 and s_noise > 0.0:
+ renoise_coeff = (
+ (sigma_next_b**2 - sigma_down_b**2 * alpha_ip1_b**2 / (alpha_down_b**2 + 1e-12))
+ .clamp(min=0.0)
+ .sqrt()
+ )
+
+ noise = randn_tensor(
+ sample_f.shape, generator=generator, device=sample_f.device, dtype=sample_f.dtype
+ )
+ x = (alpha_ip1_b / (alpha_down_b + 1e-12)) * x + noise * renoise_coeff * s_noise
+
+ prev_sample = x.to(sample.dtype)
+
+ # Advance internal step index
+ self._step_index = min(self._step_index + 1, len(self.sigmas) - 1)
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return LTXEulerAncestralRFSchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self) -> int:
+ # For compatibility with other schedulers; used e.g. in some training
+ # utilities to infer the maximum number of training timesteps.
+ return int(getattr(self.config, "num_train_timesteps", 1000))
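+
+    # Minimal usage sketch (illustrative only; the transformer call signature below is an assumption):
+    #
+    #     scheduler.set_timesteps(sigmas=[1.0, 0.75, 0.5, 0.25, 0.0], device="cuda")
+    #     for t in scheduler.timesteps:
+    #         noise_pred = transformer(latents, timestep=t, ...)
+    #         latents = scheduler.step(noise_pred, t, latents).prev_sample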
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 651532b06ddb..e95a374457e4 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -54,6 +54,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index a2eaf8eb3abd..fcebe7e21c1d 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -73,6 +73,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 5783e20de69d..7c679a255c39 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -61,6 +61,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 7b4840ffdb19..7a385f62918b 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py
index 5a978dec649b..bdc4feb0b197 100644
--- a/src/diffusers/schedulers/scheduling_unclip.py
+++ b/src/diffusers/schedulers/scheduling_unclip.py
@@ -74,6 +74,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 5ea56b300be2..0536e8d1ed7a 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == "laplace":
+
+ def alpha_bar_fn(t):
+ lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
+ snr = math.exp(lmb)
+ return math.sqrt(snr / (1 + snr))
+
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 7a98fa3da14a..3e9968d47fdd 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -6,11 +6,18 @@
import re
import warnings
from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import numpy as np
import torch
+
+if getattr(torch, "distributed", None) is not None:
+ from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
from .models import UNet2DConditionModel
from .pipelines import DiffusionPipeline
from .schedulers import SchedulerMixin
@@ -18,6 +25,7 @@
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
+ is_accelerate_available,
is_peft_available,
is_torch_npu_available,
is_torchvision_available,
@@ -31,6 +39,9 @@
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
+if is_accelerate_available():
+ from accelerate.logging import get_logger
+
if is_peft_available():
from peft import set_peft_model_state_dict
@@ -394,6 +405,86 @@ def find_nearest_bucket(h, w, bucket_options):
return best_bucket_idx
+def _to_cpu_contiguous(state_dicts: Dict[str, Any]) -> Dict[str, Any]:
+    """Detach tensors in a state dict and move them to contiguous CPU memory; non-tensor values pass through."""
+    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}
+
+
+def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
+ """
+ Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.
+ """
+
+    kwargs = {}
+    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
+
+    if fsdp_plugin is None:
+        raise ValueError(
+            "Accelerate is not configured for FSDP (no `fsdp_plugin` found on `accelerator.state`). Enable FSDP"
+            " in your Accelerate config or update your `accelerate` installation."
+        )
+
+    # Use the plugin's sharding strategy, falling back to FULL_SHARD when it is unset.
+    kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
+
+    return kwargs
+
+
+def wrap_with_fsdp(
+ model: torch.nn.Module,
+ device: Union[str, torch.device],
+ offload: bool = True,
+ use_orig_params: bool = True,
+ limit_all_gathers: bool = True,
+ fsdp_kwargs: Optional[Dict[str, Any]] = None,
+ transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
+) -> FSDP:
+ """
+ Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
+
+ Args:
+ model: Model to wrap
+ device: Target device (e.g., accelerator.device)
+ offload: Whether to enable CPU parameter offloading
+ use_orig_params: Whether to use original parameters
+ limit_all_gathers: Whether to limit all gathers
+        fsdp_kwargs: Extra FSDP arguments (e.g. sharding_strategy), usually derived from the Accelerate config
+        transformer_layer_cls: Set of transformer layer classes to auto-wrap; inferred from the model's first
+            language-model block if not provided
+
+ Returns:
+ FSDP-wrapped model
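+
+    Example (illustrative sketch; assumes an Accelerate-launched training script in which `accelerator` and
+    `model` are already defined):
+
+    ```py
+    fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
+    model = wrap_with_fsdp(model, device=accelerator.device, offload=True, fsdp_kwargs=fsdp_kwargs)
+    ```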
+ """
+
+ logger = get_logger(__name__)
+
+    if transformer_layer_cls is None:
+        # Default to the class of the model's first transformer block when no layer classes are provided.
+        inferred_cls = type(model.model.language_model.layers[0])
+        transformer_layer_cls = {inferred_cls}
+        logger.info(f"`transformer_layer_cls` not provided, auto-inferred as {inferred_cls.__name__}")
+
+    # Auto-wrap policy: each instance of the given transformer layer classes becomes its own FSDP unit.
+    auto_wrap_policy = partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=transformer_layer_cls,
+    )
+
+ config = {
+ "device_id": device,
+ "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
+ "use_orig_params": use_orig_params,
+ "limit_all_gathers": limit_all_gathers,
+ "auto_wrap_policy": auto_wrap_policy,
+ }
+
+ if fsdp_kwargs:
+ config.update(fsdp_kwargs)
+
+ fsdp_model = FSDP(model, **config)
+ return fsdp_model
+
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 440e4539e720..e726bbb46913 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -66,6 +66,7 @@
is_accelerate_version,
is_aiter_available,
is_aiter_version,
+ is_av_available,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
diff --git a/src/diffusers/utils/distributed_utils.py b/src/diffusers/utils/distributed_utils.py
new file mode 100644
index 000000000000..239b7b26200d
--- /dev/null
+++ b/src/diffusers/utils/distributed_utils.py
@@ -0,0 +1,36 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
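+# `is_torch_dist_rank_zero` is used by `diffusers.utils.logging` to drop non-DEBUG log records on every rank
+# other than global rank 0, so multi-process runs (e.g. under `torchrun`) do not emit duplicate warnings.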
+def is_torch_dist_rank_zero() -> bool:
+ if torch is None:
+ return True
+
+ dist_module = getattr(torch, "distributed", None)
+ if dist_module is None or not dist_module.is_available():
+ return True
+
+ if not dist_module.is_initialized():
+ return True
+
+ try:
+ return dist_module.get_rank() == 0
+ except (RuntimeError, ValueError):
+ return True
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index eb956a96a30f..7120ff1f6257 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -502,6 +502,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLLTX2Audio(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLLTX2Video(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKLLTXVideo(metaclass=DummyObject):
_backends = ["torch"]
@@ -952,6 +982,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class GlmImageTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HiDreamImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1147,6 +1192,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class LTX2VideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class LTXVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -2634,6 +2694,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class LTXEulerAncestralRFScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 6c28e87581b9..63f381419fda 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -167,6 +167,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class QwenImageLayeredAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageLayeredModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class QwenImageModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -557,6 +587,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class BriaFiboEditPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class BriaFiboPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -602,6 +647,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class ChromaInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class ChromaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -902,6 +962,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class Flux2KleinPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class Flux2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1112,6 +1187,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class GlmImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class HiDreamImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1877,6 +1967,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class LTX2ImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class LTX2LatentUpsamplePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class LTX2Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1892,6 +2027,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class LTXI2VLongMultiPromptPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class LTXImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 57b0a337922a..425c360a3110 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -230,6 +230,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
+_av_available, _av_version = _is_package_available("av")
def is_torch_available():
@@ -420,6 +421,10 @@ def is_kornia_available():
return _kornia_available
+def is_av_available():
+ return _av_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index 2ad6d3a47607..80e108e4a6ff 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -32,6 +32,8 @@
from tqdm import auto as tqdm_lib
+from .distributed_utils import is_torch_dist_rank_zero
+
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
@@ -47,6 +49,23 @@
_default_log_level = logging.WARNING
_tqdm_active = True
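+# Module-level singleton filter, reused so repeated `get_logger()` calls do not attach duplicate filters.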
+_rank_zero_filter = None
+
+
+class _RankZeroFilter(logging.Filter):
+ def filter(self, record):
+ # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
+ return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
+
+
+def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
+ global _rank_zero_filter
+
+ if _rank_zero_filter is None:
+ _rank_zero_filter = _RankZeroFilter()
+
+ if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
+ logger.addFilter(_rank_zero_filter)
def _get_default_logging_level() -> int:
@@ -90,6 +109,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
+ _ensure_rank_zero_filter(library_root_logger)
def _reset_library_root_logger() -> None:
@@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
name = _get_library_name()
_configure_library_root_logger()
- return logging.getLogger(name)
+ logger = logging.getLogger(name)
+ _ensure_rank_zero_filter(logger)
+ return logger
def get_verbosity() -> int:
diff --git a/tests/conftest.py b/tests/conftest.py
index fd76d1c84ee7..9558c23d3062 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -32,6 +32,8 @@
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
+ config.addinivalue_line("markers", "slow: mark test as slow")
+ config.addinivalue_line("markers", "nightly: mark test as nightly")
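+    # Registering these markers avoids PytestUnknownMarkWarning for `@pytest.mark.slow` / `@pytest.mark.nightly`
+    # tests and enables selection such as `pytest -m "not nightly"`.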
def pytest_addoption(parser):
diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py
index 91f63c4b56c4..78ef4ce151be 100644
--- a/tests/lora/test_lora_layers_auraflow.py
+++ b/tests/lora/test_lora_layers_auraflow.py
@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -114,23 +116,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
@unittest.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
pass
-
- @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index fa57b4c9c2f9..7bd54b77ca35 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 9, 16, 16, 3)
@@ -147,26 +149,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
def test_modify_padding_mode(self):
pass
- @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
-
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py
index 30eb8fbb6367..e8ee6e7a7db6 100644
--- a/tests/lora/test_lora_layers_cogview4.py
+++ b/tests/lora/test_lora_layers_cogview4.py
@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"text_encoder",
)
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -162,23 +164,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
@unittest.skip("Not supported in CogView4.")
def test_modify_padding_mode(self):
pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogView4.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogView4.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogView4.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogView4.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in CogView4.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py
index 4ae189aceb66..d970b7d7847f 100644
--- a/tests/lora/test_lora_layers_flux2.py
+++ b/tests/lora/test_lora_layers_flux2.py
@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -146,23 +148,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
@unittest.skip("Not supported in Flux2.")
def test_modify_padding_mode(self):
pass
-
- @unittest.skip("Text encoder LoRA is not supported in Flux2.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Flux2.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Flux2.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Flux2.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Flux2.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index cfd5d3146a91..e59bc5662fe1 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"text_encoder_2",
)
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -172,26 +174,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
def test_modify_padding_mode(self):
pass
- @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
-
@nightly
@require_torch_accelerator
diff --git a/tests/lora/test_lora_layers_ltx2.py b/tests/lora/test_lora_layers_ltx2.py
new file mode 100644
index 000000000000..0a4b14454f5b
--- /dev/null
+++ b/tests/lora/test_lora_layers_ltx2.py
@@ -0,0 +1,271 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
+
+from diffusers import (
+ AutoencoderKLLTX2Audio,
+ AutoencoderKLLTX2Video,
+ FlowMatchEulerDiscreteScheduler,
+ LTX2Pipeline,
+ LTX2VideoTransformer3DModel,
+)
+from diffusers.pipelines.ltx2 import LTX2TextConnectors
+from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
+from diffusers.utils.import_utils import is_peft_available
+
+from ..testing_utils import floats_tensor, require_peft_backend
+
+
+if is_peft_available():
+ from peft import LoraConfig
+
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = LTX2Pipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "cross_attention_dim": 16,
+ "audio_in_channels": 4,
+ "audio_out_channels": 4,
+ "audio_num_attention_heads": 2,
+ "audio_attention_head_dim": 4,
+ "audio_cross_attention_dim": 8,
+ "num_layers": 1,
+ "qk_norm": "rms_norm_across_heads",
+ "caption_channels": 32,
+ "rope_double_precision": False,
+ "rope_type": "split",
+ }
+ transformer_cls = LTX2VideoTransformer3DModel
+
+ vae_kwargs = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 4,
+ "block_out_channels": (8,),
+ "decoder_block_out_channels": (8,),
+ "layers_per_block": (1,),
+ "decoder_layers_per_block": (1, 1),
+ "spatio_temporal_scaling": (True,),
+ "decoder_spatio_temporal_scaling": (True,),
+ "decoder_inject_noise": (False, False),
+ "downsample_type": ("spatial",),
+ "upsample_residual": (False,),
+ "upsample_factor": (1,),
+ "timestep_conditioning": False,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ }
+ vae_cls = AutoencoderKLLTX2Video
+
+ audio_vae_kwargs = {
+ "base_channels": 4,
+ "output_channels": 2,
+ "ch_mult": (1,),
+ "num_res_blocks": 1,
+ "attn_resolutions": None,
+ "in_channels": 2,
+ "resolution": 32,
+ "latent_channels": 2,
+ "norm_type": "pixel",
+ "causality_axis": "height",
+ "dropout": 0.0,
+ "mid_block_add_attention": False,
+ "sample_rate": 16000,
+ "mel_hop_length": 160,
+ "is_causal": True,
+ "mel_bins": 8,
+ }
+ audio_vae_cls = AutoencoderKLLTX2Audio
+
+ vocoder_kwargs = {
+ "in_channels": 16, # output_channels * mel_bins = 2 * 8
+ "hidden_channels": 32,
+ "out_channels": 2,
+ "upsample_kernel_sizes": [4, 4],
+ "upsample_factors": [2, 2],
+ "resnet_kernel_sizes": [3],
+ "resnet_dilations": [[1, 3, 5]],
+ "leaky_relu_negative_slope": 0.1,
+ "output_sampling_rate": 16000,
+ }
+ vocoder_cls = LTX2Vocoder
+
+ connectors_kwargs = {
+ "caption_channels": 32, # Will be set dynamically from text_encoder
+ "text_proj_in_factor": 2, # Will be set dynamically from text_encoder
+ "video_connector_num_attention_heads": 4,
+ "video_connector_attention_head_dim": 8,
+ "video_connector_num_layers": 1,
+ "video_connector_num_learnable_registers": None,
+ "audio_connector_num_attention_heads": 4,
+ "audio_connector_attention_head_dim": 8,
+ "audio_connector_num_layers": 1,
+ "audio_connector_num_learnable_registers": None,
+ "connector_rope_base_seq_len": 32,
+ "rope_theta": 10000.0,
+ "rope_double_precision": False,
+ "causal_temporal_positioning": False,
+ "rope_type": "split",
+ }
+ connectors_cls = LTX2TextConnectors
+
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-gemma3"
+ text_encoder_cls, text_encoder_id = (
+ Gemma3ForConditionalGeneration,
+ "hf-internal-testing/tiny-gemma3",
+ )
+
+ denoiser_target_modules = ["to_q", "to_k", "to_out.0"]
+
+ supports_text_encoder_loras = False
+
+ @property
+ def output_shape(self):
+ return (1, 5, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 5
+ num_latent_frames = 2
+ latent_height = 8
+ latent_width = 8
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "a robot dancing",
+ "num_frames": num_frames,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "frame_rate": 25.0,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
+ # Override to instantiate LTX2-specific components (connectors, audio_vae, vocoder)
+ torch.manual_seed(0)
+ text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
+ tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)
+
+ # Update caption_channels and text_proj_in_factor based on text_encoder config
+ transformer_kwargs = self.transformer_kwargs.copy()
+ transformer_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size
+
+ connectors_kwargs = self.connectors_kwargs.copy()
+ connectors_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size
+ connectors_kwargs["text_proj_in_factor"] = text_encoder.config.text_config.num_hidden_layers + 1
+
+ torch.manual_seed(0)
+ transformer = self.transformer_cls(**transformer_kwargs)
+
+ torch.manual_seed(0)
+ vae = self.vae_cls(**self.vae_kwargs)
+ vae.use_framewise_encoding = False
+ vae.use_framewise_decoding = False
+
+ torch.manual_seed(0)
+ audio_vae = self.audio_vae_cls(**self.audio_vae_kwargs)
+
+ torch.manual_seed(0)
+ vocoder = self.vocoder_cls(**self.vocoder_kwargs)
+
+ torch.manual_seed(0)
+ connectors = self.connectors_cls(**connectors_kwargs)
+
+ if scheduler_cls is None:
+ scheduler_cls = self.scheduler_cls
+ scheduler = scheduler_cls(**self.scheduler_kwargs)
+
+ rank = 4
+ lora_alpha = rank if lora_alpha is None else lora_alpha
+
+ text_lora_config = LoraConfig(
+ r=rank,
+ lora_alpha=lora_alpha,
+ target_modules=self.text_encoder_target_modules,
+ init_lora_weights=False,
+ use_dora=use_dora,
+ )
+
+ denoiser_lora_config = LoraConfig(
+ r=rank,
+ lora_alpha=lora_alpha,
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ init_lora_weights=False,
+ use_dora=use_dora,
+ )
+
+ pipeline_components = {
+ "transformer": transformer,
+ "vae": vae,
+ "audio_vae": audio_vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "connectors": connectors,
+ "vocoder": vocoder,
+ }
+
+ return pipeline_components, text_lora_config, denoiser_lora_config
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in LTX2.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in LTX2.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in LTX2.")
+ def test_modify_padding_mode(self):
+ pass
diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py
index 6ab51a5e513f..095e5b577cf0 100644
--- a/tests/lora/test_lora_layers_ltx_video.py
+++ b/tests/lora/test_lora_layers_ltx_video.py
@@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -125,23 +127,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
@unittest.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self):
pass
-
- @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py
index 0417b05b33a1..da032229a785 100644
--- a/tests/lora/test_lora_layers_lumina2.py
+++ b/tests/lora/test_lora_layers_lumina2.py
@@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 4, 4, 3)
@@ -113,26 +115,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
def test_modify_padding_mode(self):
pass
- @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
-
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py
index 7be81273db77..ee8254112924 100644
--- a/tests/lora/test_lora_layers_mochi.py
+++ b/tests/lora/test_lora_layers_mochi.py
@@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 7, 16, 16, 3)
@@ -117,26 +119,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
def test_modify_padding_mode(self):
pass
- @unittest.skip("Text encoder LoRA is not supported in Mochi.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Mochi.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Mochi.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Mochi.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Mochi.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
-
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py
index 51de2f8e20e1..73fd026a670c 100644
--- a/tests/lora/test_lora_layers_qwenimage.py
+++ b/tests/lora/test_lora_layers_qwenimage.py
@@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
)
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -107,23 +109,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
@unittest.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
-
- @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
index a860b7b44f2c..97bf5cbba920 100644
--- a/tests/lora/test_lora_layers_sana.py
+++ b/tests/lora/test_lora_layers_sana.py
@@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -117,26 +119,6 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
- @unittest.skip("Text encoder LoRA is not supported in SANA.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in SANA.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in SANA.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in SANA.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in SANA.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
-
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index 5734509b410f..5ae16ab4b9da 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -121,23 +123,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
@unittest.skip("Not supported in Wan.")
def test_modify_padding_mode(self):
pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
index ab1f57bfc9da..c8acaea9bef0 100644
--- a/tests/lora/test_lora_layers_wanvace.py
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 9, 16, 16, 3)
@@ -139,26 +141,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
def test_modify_padding_mode(self):
pass
- @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
-
def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser()
diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py
index 35d1389d9612..8432ea56a6fa 100644
--- a/tests/lora/test_lora_layers_z_image.py
+++ b/tests/lora/test_lora_layers_z_image.py
@@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ supports_text_encoder_loras = False
+
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -263,23 +265,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
@unittest.skip("Not supported in ZImage.")
def test_modify_padding_mode(self):
pass
-
- @unittest.skip("Text encoder LoRA is not supported in ZImage.")
- def test_simple_inference_with_partial_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in ZImage.")
- def test_simple_inference_with_text_lora(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in ZImage.")
- def test_simple_inference_with_text_lora_and_scale(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in ZImage.")
- def test_simple_inference_with_text_lora_fused(self):
- pass
-
- @unittest.skip("Text encoder LoRA is not supported in ZImage.")
- def test_simple_inference_with_text_lora_save_load(self):
- pass
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 5fae6cac0a7f..efa49b9f4838 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests:
tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
+ supports_text_encoder_loras = True
unet_kwargs = None
transformer_cls = None
@@ -333,6 +334,9 @@ def test_simple_inference_with_text_lora(self):
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
+ if not self.supports_text_encoder_loras:
+ pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -457,6 +461,9 @@ def test_simple_inference_with_text_lora_and_scale(self):
Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected
"""
+ if not self.supports_text_encoder_loras:
+ pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -494,6 +501,9 @@ def test_simple_inference_with_text_lora_fused(self):
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
+ if not self.supports_text_encoder_loras:
+ pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -555,6 +565,9 @@ def test_simple_inference_with_text_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
+ if not self.supports_text_encoder_loras:
+ pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -593,6 +606,9 @@ def test_simple_inference_with_partial_text_lora(self):
with different ranks and some adapters removed
and makes sure it works as expected
"""
+ if not self.supports_text_encoder_loras:
+ pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
components, _, _ = self.get_dummy_components()
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
@@ -651,6 +667,9 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
+ if not self.supports_text_encoder_loras:
+ pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py
new file mode 100644
index 000000000000..ce93dfb42afe
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLLTX2Audio
+
+from ...testing_utils import (
+ floats_tensor,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLLTX2Audio
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+    def get_autoencoder_kl_ltx2_audio_config(self):
+        return {
+            "in_channels": 2,  # stereo
+ "output_channels": 2,
+ "latent_channels": 4,
+ "base_channels": 16,
+ "ch_mult": (1, 2, 4),
+ "resolution": 16,
+ "attn_resolutions": None,
+ "num_res_blocks": 2,
+ "norm_type": "pixel",
+ "causality_axis": "height",
+ "mid_block_add_attention": False,
+ "sample_rate": 16000,
+ "mel_hop_length": 160,
+ "mel_bins": 16,
+ "is_causal": True,
+ "double_z": True,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 2
+ num_frames = 8
+ num_mel_bins = 16
+
+ spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
+
+ input_dict = {"sample": spectrogram}
+ return input_dict
+
+ @property
+ def input_shape(self):
+ return (2, 5, 16)
+
+ @property
+ def output_shape(self):
+ return (2, 5, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+        init_dict = self.get_autoencoder_kl_ltx2_audio_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ # Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
+ def test_output(self):
+ super().test_output(expected_output_shape=(2, 2, 5, 16))
+
+ @unittest.skip("Unsupported test.")
+ def test_outputs_equivalence(self):
+ pass
+
+ @unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
+ def test_forward_with_norm_groups(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py
new file mode 100644
index 000000000000..146241361a82
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py
@@ -0,0 +1,103 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLLTX2Video
+
+from ...testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLLTX2Video
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_ltx_video_config(self):
+ return {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 8,
+ "block_out_channels": (8, 8, 8, 8),
+ "decoder_block_out_channels": (16, 32, 64),
+ "layers_per_block": (1, 1, 1, 1, 1),
+ "decoder_layers_per_block": (1, 1, 1, 1),
+ "spatio_temporal_scaling": (True, True, True, True),
+ "decoder_spatio_temporal_scaling": (True, True, True),
+ "decoder_inject_noise": (False, False, False, False),
+ "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+ "upsample_residual": (True, True, True),
+ "upsample_factor": (2, 2, 2),
+ "timestep_conditioning": False,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ "encoder_spatial_padding_mode": "zeros",
+            # The full model uses `reflect`, but that mode has no deterministic backward implementation, so use `zeros` here
+ "decoder_spatial_padding_mode": "zeros",
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ input_dict = {"sample": image}
+ return input_dict
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_ltx_video_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "LTX2VideoEncoder3d",
+ "LTX2VideoDecoder3d",
+ "LTX2VideoDownBlock3D",
+ "LTX2VideoMidBlock3d",
+ "LTX2VideoUpBlock3d",
+ }
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Unsupported test.")
+ def test_outputs_equivalence(self):
+ pass
+
+ @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
+ def test_forward_with_norm_groups(self):
+ pass
diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py
new file mode 100644
index 000000000000..af9ef0623891
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_ltx2.py
@@ -0,0 +1,222 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import LTX2VideoTransformer3DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = LTX2VideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ # Common
+ batch_size = 2
+
+ # Video
+ num_frames = 2
+ num_channels = 4
+ height = 16
+ width = 16
+
+ # Audio
+ audio_num_frames = 9
+ audio_num_channels = 2
+ num_mel_bins = 2
+
+ # Text
+ embedding_dim = 16
+ sequence_length = 16
+
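+        # Latents are passed to the transformer already flattened/patchified: video as
+        # (batch, frames * height * width, channels) and audio as (batch, audio_frames, audio_channels * mel_bins),
+        # matching the shapes constructed below.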
+ hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
+ audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
+ torch_device
+ )
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
+ timestep = torch.rand((batch_size,)).to(torch_device) * 1000
+
+ return {
+ "hidden_states": hidden_states,
+ "audio_hidden_states": audio_hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "audio_encoder_hidden_states": audio_encoder_hidden_states,
+ "timestep": timestep,
+ "encoder_attention_mask": encoder_attention_mask,
+ "num_frames": num_frames,
+ "height": height,
+ "width": width,
+ "audio_num_frames": audio_num_frames,
+ "fps": 25.0,
+ }
+
+ @property
+ def input_shape(self):
+ return (512, 4)
+
+ @property
+ def output_shape(self):
+ return (512, 4)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "cross_attention_dim": 16,
+ "audio_in_channels": 4,
+ "audio_out_channels": 4,
+ "audio_num_attention_heads": 2,
+ "audio_attention_head_dim": 4,
+ "audio_cross_attention_dim": 8,
+ "num_layers": 2,
+ "qk_norm": "rms_norm_across_heads",
+ "caption_channels": 16,
+ "rope_double_precision": False,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"LTX2VideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ # def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
+ # torch.manual_seed(seed)
+ # init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+ # # Calculate dummy inputs in a custom manner to ensure compatibility with original code
+ # batch_size = 2
+ # num_frames = 9
+ # latent_frames = 2
+ # text_embedding_dim = 16
+ # text_seq_len = 16
+ # fps = 25.0
+ # sampling_rate = 16000.0
+ # hop_length = 160.0
+
+ # sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
+ # timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
+
+ # num_channels = 4
+ # latent_height = 4
+ # latent_width = 4
+ # hidden_states = torch.randn(
+ # (batch_size, num_channels, latent_frames, latent_height, latent_width),
+ # generator=torch.manual_seed(seed),
+ # dtype=dtype,
+ # device="cpu",
+ # )
+ # # Patchify video latents (with patch_size (1, 1, 1))
+ # hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
+ # hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ # encoder_hidden_states = torch.randn(
+ # (batch_size, text_seq_len, text_embedding_dim),
+ # generator=torch.manual_seed(seed),
+ # dtype=dtype,
+ # device="cpu",
+ # )
+
+ # audio_num_channels = 2
+ # num_mel_bins = 2
+ # latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
+ # audio_hidden_states = torch.randn(
+ # (batch_size, audio_num_channels, latent_length, num_mel_bins),
+ # generator=torch.manual_seed(seed),
+ # dtype=dtype,
+ # device="cpu",
+ # )
+ # # Patchify audio latents
+ # audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
+ # audio_encoder_hidden_states = torch.randn(
+ # (batch_size, text_seq_len, text_embedding_dim),
+ # generator=torch.manual_seed(seed),
+ # dtype=dtype,
+ # device="cpu",
+ # )
+
+ # inputs_dict = {
+ # "hidden_states": hidden_states.to(device=torch_device),
+ # "audio_hidden_states": audio_hidden_states.to(device=torch_device),
+ # "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
+ # "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
+ # "timestep": timestep,
+ # "num_frames": latent_frames,
+ # "height": latent_height,
+ # "width": latent_width,
+ # "audio_num_frames": num_frames,
+ # "fps": 25.0,
+ # }
+
+ # model = self.model_class.from_pretrained(
+ # "diffusers-internal-dev/dummy-ltx2",
+ # subfolder="transformer",
+ # device_map="cpu",
+ # )
+ # # torch.manual_seed(seed)
+ # # model = self.model_class(**init_dict)
+ # model.to(torch_device)
+ # model.eval()
+
+ # with attention_backend("native"):
+ # with torch.no_grad():
+ # output = model(**inputs_dict)
+
+ # video_output, audio_output = output.to_tuple()
+
+ # self.assertIsNotNone(video_output)
+ # self.assertIsNotNone(audio_output)
+
+ # # input & output have to have the same shape
+ # video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
+ # self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
+ # audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
+ # self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
+
+ # # Check against expected slice
+ # # fmt: off
+ # video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
+ # audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
+ # # fmt: on
+
+ # video_output_flat = video_output.cpu().flatten().float()
+ # video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
+ # self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
+
+ # audio_output_flat = audio_output.cpu().flatten().float()
+ # audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
+ # self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
+
+
+class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = LTX2VideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index b24fa90503ef..384954dfbad7 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -15,10 +15,10 @@
import unittest
-import pytest
import torch
from diffusers import QwenImageTransformer2DModel
+from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
@@ -68,7 +68,6 @@ def prepare_dummy_input(self, height=4, width=4):
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
- "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
}
def prepare_init_args_and_inputs_for_common(self):
@@ -91,6 +90,180 @@ def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+ def test_infers_text_seq_len_from_mask(self):
+ """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
+ init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
+ encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
+ encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
+
+ rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
+ inputs["encoder_hidden_states"], encoder_hidden_states_mask
+ )
+
+ # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
+ self.assertIsInstance(rope_text_seq_len, int)
+
+ # Verify per_sample_len is computed correctly (max valid position + 1 = 2)
+ self.assertIsInstance(per_sample_len, torch.Tensor)
+ self.assertEqual(int(per_sample_len.max().item()), 2)
+
+ # Verify mask is normalized to bool dtype
+ self.assertTrue(normalized_mask.dtype == torch.bool)
+ self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
+
+ # Verify rope_text_seq_len is at least the sequence length
+ self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
+
+ # Test 2: Verify model runs successfully with inferred values
+ inputs["encoder_hidden_states_mask"] = normalized_mask
+ with torch.no_grad():
+ output = model(**inputs)
+ self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+
+ # Test 3: Different mask pattern (padding at beginning)
+ encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
+ encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
+ encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
+
+ rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
+ inputs["encoder_hidden_states"], encoder_hidden_states_mask2
+ )
+
+ # Max valid position is 6 (last token), so per_sample_len should be 7
+ self.assertEqual(int(per_sample_len2.max().item()), 7)
+ self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
+
+ # Test 4: No mask provided (None case)
+ rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
+ inputs["encoder_hidden_states"], None
+ )
+ self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
+ self.assertIsInstance(rope_text_seq_len_none, int)
+ self.assertIsNone(per_sample_len_none)
+ self.assertIsNone(normalized_mask_none)
+
+ def test_non_contiguous_attention_mask(self):
+ """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
+ init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
+ encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
+ # Pattern: [True, False, True, False, True, False, False]
+ encoder_hidden_states_mask[:, 1] = 0
+ encoder_hidden_states_mask[:, 3] = 0
+ encoder_hidden_states_mask[:, 5:] = 0
+
+ inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
+ inputs["encoder_hidden_states"], encoder_hidden_states_mask
+ )
+ self.assertEqual(int(per_sample_len.max().item()), 5)
+ self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
+ self.assertIsInstance(inferred_rope_len, int)
+ self.assertTrue(normalized_mask.dtype == torch.bool)
+
+ inputs["encoder_hidden_states_mask"] = normalized_mask
+
+ with torch.no_grad():
+ output = model(**inputs)
+
+ self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+
+ def test_txt_seq_lens_deprecation(self):
+ """Test that passing txt_seq_lens raises a deprecation warning."""
+ init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ # Prepare inputs with txt_seq_lens (deprecated parameter)
+ txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
+
+ # Remove encoder_hidden_states_mask to use the deprecated path
+ inputs_with_deprecated = inputs.copy()
+ inputs_with_deprecated.pop("encoder_hidden_states_mask")
+ inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
+
+ # Test that deprecation warning is raised
+ with self.assertWarns(FutureWarning) as warning_context:
+ with torch.no_grad():
+ output = model(**inputs_with_deprecated)
+
+ # Verify the warning message mentions the deprecation
+ warning_message = str(warning_context.warning)
+ self.assertIn("txt_seq_lens", warning_message)
+ self.assertIn("deprecated", warning_message)
+ self.assertIn("encoder_hidden_states_mask", warning_message)
+
+ # Verify the model still works correctly despite the deprecation
+ self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+
+ def test_layered_model_with_mask(self):
+ """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
+ # Create layered model config
+ init_dict = {
+ "patch_size": 2,
+ "in_channels": 16,
+ "out_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 16,
+ "num_attention_heads": 3,
+ "joint_attention_dim": 16,
+ "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
+ "use_layer3d_rope": True, # Enable layered RoPE
+ "use_additional_t_cond": True, # Enable additional time conditioning
+ }
+
+ model = self.model_class(**init_dict).to(torch_device)
+
+ # Verify the model uses QwenEmbedLayer3DRope
+ from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
+
+ self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
+
+ # Test single generation with layered structure
+ batch_size = 1
+ text_seq_len = 7
+ img_h, img_w = 4, 4
+ layers = 4
+
+ # For layered model: (layers + 1) because we have N layers + 1 combined image
+ hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
+ encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
+
+ # Create mask with some padding
+ encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
+ encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
+
+ timestep = torch.tensor([1.0]).to(torch_device)
+
+ # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
+        additional_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
+
+ # Layer structure: 4 layers + 1 condition image
+ img_shapes = [
+ [
+ (1, img_h, img_w), # layer 0
+ (1, img_h, img_w), # layer 1
+ (1, img_h, img_w), # layer 2
+ (1, img_h, img_w), # layer 3
+ (1, img_h, img_w), # condition image (last one gets special treatment)
+ ]
+ ]
+
+ with torch.no_grad():
+ output = model(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ timestep=timestep,
+ img_shapes=img_shapes,
+                additional_t_cond=additional_t_cond,
+ )
+
+ self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
+
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
@@ -101,6 +274,5 @@ def prepare_init_args_and_inputs_for_common(self):
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
- @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py
new file mode 100644
index 000000000000..9c5fd5be326d
--- /dev/null
+++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py
@@ -0,0 +1,272 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import tempfile
+from collections import deque
+from typing import List
+
+import numpy as np
+import torch
+
+from diffusers import FluxTransformer2DModel
+from diffusers.modular_pipelines import (
+ ComponentSpec,
+ InputParam,
+ ModularPipelineBlocks,
+ OutputParam,
+ PipelineState,
+ WanModularPipeline,
+)
+
+from ..testing_utils import nightly, require_torch, slow
+
+
+class DummyCustomBlockSimple(ModularPipelineBlocks):
+ def __init__(self, use_dummy_model_component=False):
+ self.use_dummy_model_component = use_dummy_model_component
+ super().__init__()
+
+ @property
+ def expected_components(self):
+ if self.use_dummy_model_component:
+ return [ComponentSpec("transformer", FluxTransformer2DModel)]
+ else:
+ return []
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return []
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "output_prompt",
+ type_hint=str,
+ description="Modified prompt",
+ )
+ ]
+
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ old_prompt = block_state.prompt
+ block_state.output_prompt = "Modular diffusers + " + old_prompt
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
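+# Source for the same block as a string, so the save/load test below can write it next to the saved
+# config and reload it with `trust_remote_code=True`.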
+CODE_STR = """
+from diffusers.modular_pipelines import (
+ ComponentSpec,
+ InputParam,
+ ModularPipelineBlocks,
+ OutputParam,
+ PipelineState,
+ WanModularPipeline,
+)
+from typing import List
+
+class DummyCustomBlockSimple(ModularPipelineBlocks):
+ def __init__(self, use_dummy_model_component=False):
+ self.use_dummy_model_component = use_dummy_model_component
+ super().__init__()
+
+ @property
+ def expected_components(self):
+ if self.use_dummy_model_component:
+ return [ComponentSpec("transformer", FluxTransformer2DModel)]
+ else:
+ return []
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return []
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "output_prompt",
+ type_hint=str,
+ description="Modified prompt",
+ )
+ ]
+
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ old_prompt = block_state.prompt
+ block_state.output_prompt = "Modular diffusers + " + old_prompt
+ self.set_block_state(state, block_state)
+
+ return components, state
+"""
+
+
+class TestModularCustomBlocks:
+ def _test_block_properties(self, block):
+ assert not block.expected_components
+ assert not block.intermediate_inputs
+
+ actual_inputs = [inp.name for inp in block.inputs]
+ actual_intermediate_outputs = [out.name for out in block.intermediate_outputs]
+ assert actual_inputs == ["prompt"]
+ assert actual_intermediate_outputs == ["output_prompt"]
+
+ def test_custom_block_properties(self):
+ custom_block = DummyCustomBlockSimple()
+ self._test_block_properties(custom_block)
+
+ def test_custom_block_output(self):
+ custom_block = DummyCustomBlockSimple()
+ pipe = custom_block.init_pipeline()
+ prompt = "Diffusers is nice"
+ output = pipe(prompt=prompt)
+
+ actual_inputs = [inp.name for inp in custom_block.inputs]
+ actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs]
+ assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs)
+
+ output_prompt = output.values["output_prompt"]
+ assert output_prompt.startswith("Modular diffusers + ")
+
+ def test_custom_block_saving_loading(self):
+ custom_block = DummyCustomBlockSimple()
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ custom_block.save_pretrained(tmpdir)
+ assert any("modular_config.json" in k for k in os.listdir(tmpdir))
+
+ with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
+ config = json.load(f)
+ auto_map = config["auto_map"]
+ assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
+
+        # For now, the Python script that implements the custom block has to be pushed to the Hub manually.
+        # This is why we save the script separately here, next to the saved config.
+ code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
+ with open(code_path, "w") as f:
+ f.write(CODE_STR)
+
+ loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)
+
+ pipe = loaded_custom_block.init_pipeline()
+ prompt = "Diffusers is nice"
+ output = pipe(prompt=prompt)
+
+ actual_inputs = [inp.name for inp in loaded_custom_block.inputs]
+ actual_intermediate_outputs = [out.name for out in loaded_custom_block.intermediate_outputs]
+ assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs)
+
+ output_prompt = output.values["output_prompt"]
+ assert output_prompt.startswith("Modular diffusers + ")
+
+ def test_custom_block_supported_components(self):
+ custom_block = DummyCustomBlockSimple(use_dummy_model_component=True)
+ pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe")
+ pipe.load_components()
+
+ assert len(pipe.components) == 1
+ assert pipe.component_names[0] == "transformer"
+
+ def test_custom_block_loads_from_hub(self):
+ repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
+ block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
+ self._test_block_properties(block)
+
+ pipe = block.init_pipeline()
+
+ prompt = "Diffusers is nice"
+ output = pipe(prompt=prompt)
+ output_prompt = output.values["output_prompt"]
+ assert output_prompt.startswith("Modular diffusers + ")
+
+
+@slow
+@nightly
+@require_torch
+class TestKreaCustomBlocksIntegration:
+ repo_id = "krea/krea-realtime-video"
+
+ def test_loading_from_hub(self):
+ blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
+ block_names = sorted(blocks.sub_blocks)
+
+ assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"])
+
+ pipe = WanModularPipeline(blocks, self.repo_id)
+ pipe.load_components(
+ trust_remote_code=True,
+ device_map="cuda",
+ torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
+ )
+ assert len(pipe.components) == 7
+ assert sorted(pipe.components) == sorted(
+ ["text_encoder", "tokenizer", "guider", "scheduler", "vae", "transformer", "video_processor"]
+ )
+
+ def test_forward(self):
+ blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
+ pipe = WanModularPipeline(blocks, self.repo_id)
+ pipe.load_components(
+ trust_remote_code=True,
+ device_map="cuda",
+ torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
+ )
+
+ num_frames_per_block = 2
+ num_blocks = 2
+
+ state = PipelineState()
+ state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))
+
+ prompt = ["a cat sitting on a boat"]
+
+ for block in pipe.transformer.blocks:
+ block.self_attn.fuse_projections()
+
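+        # Generate the video block by block; each iteration extends the decoded frames, which are
+        # checked against fixed slices below.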
+ for block_idx in range(num_blocks):
+ state = pipe(
+ state,
+ prompt=prompt,
+ num_inference_steps=2,
+ num_blocks=num_blocks,
+ num_frames_per_block=num_frames_per_block,
+ block_idx=block_idx,
+ generator=torch.manual_seed(42),
+ )
+ current_frames = np.array(state.values["videos"][0])
+ current_frames_flat = current_frames.flatten()
+ actual_slices = np.concatenate([current_frames_flat[:4], current_frames_flat[-4:]]).tolist()
+
+ if block_idx == 0:
+ assert current_frames.shape == (5, 480, 832, 3)
+ expected_slices = np.array([211, 229, 238, 208, 195, 180, 188, 193])
+ else:
+ assert current_frames.shape == (8, 480, 832, 3)
+ expected_slices = np.array([179, 203, 214, 176, 194, 181, 187, 191])
+
+ assert np.allclose(actual_slices, expected_slices)
diff --git a/tests/pipelines/bria_fibo_edit/__init__.py b/tests/pipelines/bria_fibo_edit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py
new file mode 100644
index 000000000000..5376c4b5e03f
--- /dev/null
+++ b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py
@@ -0,0 +1,192 @@
+# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
+
+from diffusers import (
+ AutoencoderKLWan,
+ BriaFiboEditPipeline,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from tests.pipelines.test_pipelines_common import PipelineTesterMixin
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+class BriaFiboEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = BriaFiboEditPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale"])
+ batch_params = frozenset(["prompt"])
+ test_xformers_attention = False
+ test_layerwise_casting = False
+ test_group_offloading = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = BriaFiboTransformer2DModel(
+ patch_size=1,
+ in_channels=16,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=2,
+ joint_attention_dim=64,
+ text_encoder_dim=32,
+ pooled_projection_dim=None,
+ axes_dims_rope=[0, 4, 4],
+ )
+
+ vae = AutoencoderKLWan(
+ base_dim=80,
+ decoder_base_dim=128,
+ dim_mult=[1, 2, 4, 4],
+ dropout=0.0,
+ in_channels=12,
+ latents_mean=[0.0] * 16,
+ latents_std=[1.0] * 16,
+ is_residual=True,
+ num_res_blocks=2,
+ out_channels=12,
+ patch_size=2,
+ scale_factor_spatial=16,
+ scale_factor_temporal=4,
+ temperal_downsample=[False, True, True],
+ z_dim=16,
+ )
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+ inputs = {
+ "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}',
+ "negative_prompt": "bad, ugly",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 192,
+ "width": 336,
+ "output_type": "np",
+ }
+ image = Image.new("RGB", (336, 192), (255, 255, 255))
+ inputs["image"] = image
+ return inputs
+
+ @unittest.skip(reason="will not be supported due to dim-fusion")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip(reason="Batching is not supported yet")
+ def test_num_images_per_prompt(self):
+ pass
+
+ @unittest.skip(reason="Batching is not supported yet")
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(reason="Batching is not supported yet")
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_bria_fibo_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = {"edit_instruction": "a different prompt"}
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ assert max_diff > 1e-6
+
+ def test_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (64, 64), (32, 64)]
+ for height, width in height_width_pairs:
+ expected_height = height
+ expected_width = width
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_bria_fibo_edit_mask(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L")
+
+ inputs.update({"mask": mask})
+ output = pipe(**inputs).images[0]
+
+ assert output.shape == (192, 336, 3)
+
+ def test_bria_fibo_edit_mask_image_size_mismatch(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L")
+
+ inputs.update({"mask": mask})
+ with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"):
+ pipe(**inputs)
+
+ def test_bria_fibo_edit_mask_no_image(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L")
+
+        # Remove the image from the inputs (get_dummy_inputs adds one by default) so only the mask is passed
+ inputs.pop("image", None)
+ inputs.update({"mask": mask})
+
+ with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"):
+ pipe(**inputs)
diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py
new file mode 100644
index 000000000000..8ed9bf3d1e91
--- /dev/null
+++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py
@@ -0,0 +1,183 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
+
+from diffusers import (
+ AutoencoderKLFlux2,
+ FlowMatchEulerDiscreteScheduler,
+ Flux2KleinPipeline,
+ Flux2Transformer2DModel,
+)
+
+from ...testing_utils import torch_device
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
+
+
+class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Flux2KleinPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ supports_dduf = False
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = Flux2Transformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=16,
+ timestep_guidance_channels=256,
+ axes_dims_rope=[4, 4, 4, 4],
+ guidance_embeds=False,
+ )
+
+ # Create minimal Qwen3 config
+ config = Qwen3Config(
+ intermediate_size=16,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ vocab_size=151936,
+ max_position_embeddings=512,
+ )
+ torch.manual_seed(0)
+ text_encoder = Qwen3ForCausalLM(config)
+
+ # Use a simple tokenizer for testing
+ tokenizer = Qwen2TokenizerFast.from_pretrained(
+ "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLFlux2(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "a dog is dancing",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 64,
+ "output_type": "np",
+ "text_encoder_out_layers": (1,),
+ }
+ return inputs
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ pipe.transformer.fuse_qkv_projections()
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
+ ("Fusion of QKV projections shouldn't affect the outputs."),
+ )
+ self.assertTrue(
+ np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
+ ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
+ )
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
+ ("Original outputs should match when fused QKV projections are disabled."),
+ )
+
+ def test_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
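+            # The pipeline rounds requested dimensions down to a multiple of vae_scale_factor * 2,
+            # so compute the expected output size the same way.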
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ self.assertEqual(
+ (output_height, output_width),
+ (expected_height, expected_width),
+ f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
+ )
+
+ def test_image_input(self):
+ device = "cpu"
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
+ inputs = self.get_dummy_inputs(device)
+
+ inputs["image"] = Image.new("RGB", (64, 64))
+ image = pipe(**inputs).images.flatten()
+ generated_slice = np.concatenate([image[:8], image[-8:]])
+ # fmt: off
+ expected_slice = np.array(
+ [
+                0.8255048, 0.66054785, 0.6643694, 0.67462724, 0.5494932, 0.3480271, 0.52535003, 0.44510138, 0.23549396, 0.21372932, 0.21166152, 0.63198495, 0.49942136, 0.39147034, 0.49156153, 0.3713916
+ ]
+ )
+ # fmt: on
+ assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)
+
+ @unittest.skip("Needs to be revisited")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/glm_image/__init__.py b/tests/pipelines/glm_image/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py
new file mode 100644
index 000000000000..7a380b99b0fb
--- /dev/null
+++ b/tests/pipelines/glm_image/test_glm_image.py
@@ -0,0 +1,227 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, require_transformers_version_greater
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+if is_transformers_version(">=", "5.0.0.dev0"):
+ from transformers import GlmImageConfig, GlmImageForConditionalGeneration, GlmImageProcessor
+
+
+enable_full_determinism()
+
+
+@require_transformers_version_greater("4.57.4")
+@require_torch_accelerator
+class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = GlmImagePipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_attention_slicing = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ glm_config = GlmImageConfig(
+ text_config={
+ "vocab_size": 168064,
+ "hidden_size": 32,
+ "intermediate_size": 32,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "max_position_embeddings": 512,
+ "vision_vocab_size": 128,
+ "rope_parameters": {"mrope_section": (4, 2, 2)},
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 32,
+ "num_heads": 2,
+ "image_size": 32,
+ "patch_size": 8,
+ "intermediate_size": 32,
+ },
+ vq_config={"embed_dim": 32, "num_embeddings": 128, "latent_channels": 32},
+ )
+
+ torch.manual_seed(0)
+ vision_language_encoder = GlmImageForConditionalGeneration(glm_config)
+
+ processor = GlmImageProcessor.from_pretrained("zai-org/GLM-Image", subfolder="processor")
+
+ torch.manual_seed(0)
+ # For GLM-Image, the relationship between components must satisfy:
+ # patch_size × vae_scale_factor = 16 (since AR tokens are upsampled 2× from d32)
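+        # Here: patch_size=2 and the dummy VAE below has four block_out_channels (scale factor 2**3 = 8), so 2 * 8 = 16.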
+ transformer = GlmImageTransformer2DModel(
+ patch_size=2,
+ in_channels=4,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=8,
+ num_attention_heads=2,
+ text_embed_dim=text_encoder.config.hidden_size,
+ time_embed_dim=16,
+ condition_dim=8,
+ prior_vq_quantizer_codebook_size=128,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=(4, 8, 16, 16),
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=4,
+ sample_size=128,
+ latents_mean=[0.0] * 4,
+ latents_std=[1.0] * 4,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ components = {
+ "tokenizer": tokenizer,
+ "processor": processor,
+ "text_encoder": text_encoder,
+ "vision_language_encoder": vision_language_encoder,
+ "vae": vae,
+ "transformer": transformer,
+ "scheduler": scheduler,
+ }
+
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ height, width = 32, 32
+
+ inputs = {
+ "prompt": "A photo of a cat",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.5,
+ "height": height,
+ "width": width,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images[0]
+ generated_slice = image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
+
+ # fmt: off
+ expected_slice = np.array(
+ [
+ 0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
+ ]
+ )
+ # fmt: on
+
+ self.assertEqual(image.shape, (3, 32, 32))
+ self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))
+
+ @unittest.skip("Not supported.")
+ def test_inference_batch_single_identical(self):
+ # GLM-Image has batch_size=1 constraint due to AR model
+ pass
+
+ @unittest.skip("Not supported.")
+ def test_inference_batch_consistent(self):
+ # GLM-Image has batch_size=1 constraint due to AR model
+ pass
+
+ @unittest.skip("Not supported.")
+ def test_num_images_per_prompt(self):
+ # GLM-Image has batch_size=1 constraint due to AR model
+ pass
+
+ @unittest.skip("Needs to be revisited.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("Needs to be revisited.")
+ def test_pipeline_level_group_offloading_inference(self):
+ pass
+
+ @unittest.skip(
+ "Follow set of tests are relaxed because this pipeline doesn't guarantee same outputs for the same inputs in consecutive runs."
+ )
+ def test_dict_tuple_outputs_equivalent(self):
+ pass
+
+ @unittest.skip("Skipped")
+ def test_cpu_offload_forward_pass_twice(self):
+ pass
+
+ @unittest.skip("Skipped")
+ def test_sequential_offload_forward_pass_twice(self):
+ pass
+
+ @unittest.skip("Skipped")
+ def test_float16_inference(self):
+ pass
+
+ @unittest.skip("Skipped")
+ def test_save_load_float16(self):
+ pass
+
+ @unittest.skip("Skipped")
+ def test_save_load_local(self):
+ pass
diff --git a/tests/pipelines/ltx2/__init__.py b/tests/pipelines/ltx2/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py
new file mode 100644
index 000000000000..6ffc23725022
--- /dev/null
+++ b/tests/pipelines/ltx2/test_ltx2.py
@@ -0,0 +1,239 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
+
+from diffusers import (
+ AutoencoderKLLTX2Audio,
+ AutoencoderKLLTX2Video,
+ FlowMatchEulerDiscreteScheduler,
+ LTX2Pipeline,
+ LTX2VideoTransformer3DModel,
+)
+from diffusers.pipelines.ltx2 import LTX2TextConnectors
+from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
+
+from ...testing_utils import enable_full_determinism
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = LTX2Pipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "audio_latents",
+ "output_type",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_attention_slicing = False
+ test_xformers_attention = False
+ supports_dduf = False
+
+ base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
+
+ def get_dummy_components(self):
+ tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
+ text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
+
+ torch.manual_seed(0)
+ transformer = LTX2VideoTransformer3DModel(
+ in_channels=4,
+ out_channels=4,
+ patch_size=1,
+ patch_size_t=1,
+ num_attention_heads=2,
+ attention_head_dim=8,
+ cross_attention_dim=16,
+ audio_in_channels=4,
+ audio_out_channels=4,
+ audio_num_attention_heads=2,
+ audio_attention_head_dim=4,
+ audio_cross_attention_dim=8,
+ num_layers=2,
+ qk_norm="rms_norm_across_heads",
+ caption_channels=text_encoder.config.text_config.hidden_size,
+ rope_double_precision=False,
+ rope_type="split",
+ )
+
+ torch.manual_seed(0)
+ connectors = LTX2TextConnectors(
+ caption_channels=text_encoder.config.text_config.hidden_size,
+ text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
+ video_connector_num_attention_heads=4,
+ video_connector_attention_head_dim=8,
+ video_connector_num_layers=1,
+ video_connector_num_learnable_registers=None,
+ audio_connector_num_attention_heads=4,
+ audio_connector_attention_head_dim=8,
+ audio_connector_num_layers=1,
+ audio_connector_num_learnable_registers=None,
+ connector_rope_base_seq_len=32,
+ rope_theta=10000.0,
+ rope_double_precision=False,
+ causal_temporal_positioning=False,
+ rope_type="split",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLLTX2Video(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ block_out_channels=(8,),
+ decoder_block_out_channels=(8,),
+ layers_per_block=(1,),
+ decoder_layers_per_block=(1, 1),
+ spatio_temporal_scaling=(True,),
+ decoder_spatio_temporal_scaling=(True,),
+ decoder_inject_noise=(False, False),
+ downsample_type=("spatial",),
+ upsample_residual=(False,),
+ upsample_factor=(1,),
+ timestep_conditioning=False,
+ patch_size=1,
+ patch_size_t=1,
+ encoder_causal=True,
+ decoder_causal=False,
+ )
+ vae.use_framewise_encoding = False
+ vae.use_framewise_decoding = False
+
+ torch.manual_seed(0)
+ audio_vae = AutoencoderKLLTX2Audio(
+ base_channels=4,
+ output_channels=2,
+ ch_mult=(1,),
+ num_res_blocks=1,
+ attn_resolutions=None,
+ in_channels=2,
+ resolution=32,
+ latent_channels=2,
+ norm_type="pixel",
+ causality_axis="height",
+ dropout=0.0,
+ mid_block_add_attention=False,
+ sample_rate=16000,
+ mel_hop_length=160,
+ is_causal=True,
+ mel_bins=8,
+ )
+
+ torch.manual_seed(0)
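+        # The vocoder's input width is tied to the audio VAE output: output_channels * mel_bins = 2 * 8 = 16 here.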
+ vocoder = LTX2Vocoder(
+ in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
+ hidden_channels=32,
+ out_channels=2,
+ upsample_kernel_sizes=[4, 4],
+ upsample_factors=[2, 2],
+ resnet_kernel_sizes=[3],
+ resnet_dilations=[[1, 3, 5]],
+ leaky_relu_negative_slope=0.1,
+ output_sampling_rate=16000,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "audio_vae": audio_vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "connectors": connectors,
+ "vocoder": vocoder,
+ }
+
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "a robot dancing",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "num_frames": 5,
+ "frame_rate": 25.0,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = pipe(**inputs)
+ video = output.frames
+ audio = output.audio
+
+ self.assertEqual(video.shape, (1, 5, 3, 32, 32))
+ self.assertEqual(audio.shape[0], 1)
+ self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
+
+ # fmt: off
+ expected_video_slice = torch.tensor(
+ [
+ 0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184
+ ]
+ )
+ expected_audio_slice = torch.tensor(
+ [
+ 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
+ ]
+ )
+ # fmt: on
+
+ video = video.flatten()
+ audio = audio.flatten()
+ generated_video_slice = torch.cat([video[:8], video[-8:]])
+ generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
+
+ assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
+ assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py
new file mode 100644
index 000000000000..1edae9c0e098
--- /dev/null
+++ b/tests/pipelines/ltx2/test_ltx2_image2video.py
@@ -0,0 +1,241 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
+
+from diffusers import (
+ AutoencoderKLLTX2Audio,
+ AutoencoderKLLTX2Video,
+ FlowMatchEulerDiscreteScheduler,
+ LTX2ImageToVideoPipeline,
+ LTX2VideoTransformer3DModel,
+)
+from diffusers.pipelines.ltx2 import LTX2TextConnectors
+from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
+
+from ...testing_utils import enable_full_determinism
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = LTX2ImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "audio_latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_attention_slicing = False
+ test_xformers_attention = False
+ supports_dduf = False
+
+ base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
+
+ def get_dummy_components(self):
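+ # Build tiny versions of every pipeline component so the fast test can run on CPU.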
+ tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
+ text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
+
+ torch.manual_seed(0)
+ transformer = LTX2VideoTransformer3DModel(
+ in_channels=4,
+ out_channels=4,
+ patch_size=1,
+ patch_size_t=1,
+ num_attention_heads=2,
+ attention_head_dim=8,
+ cross_attention_dim=16,
+ audio_in_channels=4,
+ audio_out_channels=4,
+ audio_num_attention_heads=2,
+ audio_attention_head_dim=4,
+ audio_cross_attention_dim=8,
+ num_layers=2,
+ qk_norm="rms_norm_across_heads",
+ caption_channels=text_encoder.config.text_config.hidden_size,
+ rope_double_precision=False,
+ rope_type="split",
+ )
+
+ torch.manual_seed(0)
+ connectors = LTX2TextConnectors(
+ caption_channels=text_encoder.config.text_config.hidden_size,
+ text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
+ video_connector_num_attention_heads=4,
+ video_connector_attention_head_dim=8,
+ video_connector_num_layers=1,
+ video_connector_num_learnable_registers=None,
+ audio_connector_num_attention_heads=4,
+ audio_connector_attention_head_dim=8,
+ audio_connector_num_layers=1,
+ audio_connector_num_learnable_registers=None,
+ connector_rope_base_seq_len=32,
+ rope_theta=10000.0,
+ rope_double_precision=False,
+ causal_temporal_positioning=False,
+ rope_type="split",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLLTX2Video(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ block_out_channels=(8,),
+ decoder_block_out_channels=(8,),
+ layers_per_block=(1,),
+ decoder_layers_per_block=(1, 1),
+ spatio_temporal_scaling=(True,),
+ decoder_spatio_temporal_scaling=(True,),
+ decoder_inject_noise=(False, False),
+ downsample_type=("spatial",),
+ upsample_residual=(False,),
+ upsample_factor=(1,),
+ timestep_conditioning=False,
+ patch_size=1,
+ patch_size_t=1,
+ encoder_causal=True,
+ decoder_causal=False,
+ )
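+ # Disable framewise encoding/decoding so the tiny VAE processes all frames in a single pass.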
+ vae.use_framewise_encoding = False
+ vae.use_framewise_decoding = False
+
+ torch.manual_seed(0)
+ audio_vae = AutoencoderKLLTX2Audio(
+ base_channels=4,
+ output_channels=2,
+ ch_mult=(1,),
+ num_res_blocks=1,
+ attn_resolutions=None,
+ in_channels=2,
+ resolution=32,
+ latent_channels=2,
+ norm_type="pixel",
+ causality_axis="height",
+ dropout=0.0,
+ mid_block_add_attention=False,
+ sample_rate=16000,
+ mel_hop_length=160,
+ is_causal=True,
+ mel_bins=8,
+ )
+
+ torch.manual_seed(0)
+ vocoder = LTX2Vocoder(
+ in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
+ hidden_channels=32,
+ out_channels=2,
+ upsample_kernel_sizes=[4, 4],
+ upsample_factors=[2, 2],
+ resnet_kernel_sizes=[3],
+ resnet_dilations=[[1, 3, 5]],
+ leaky_relu_negative_slope=0.1,
+ output_sampling_rate=16000,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "audio_vae": audio_vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "connectors": connectors,
+ "vocoder": vocoder,
+ }
+
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
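+ # On MPS, seed the global RNG since a device-local generator may not be supported.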
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
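+ # Random 32x32 RGB image used as the conditioning input for image-to-video generation.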
+ image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
+
+ inputs = {
+ "image": image,
+ "prompt": "a robot dancing",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "num_frames": 5,
+ "frame_rate": 25.0,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
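+ # Run a 2-step generation on CPU and compare the video/audio outputs against fixed reference slices.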
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = pipe(**inputs)
+ video = output.frames
+ audio = output.audio
+
+ self.assertEqual(video.shape, (1, 5, 3, 32, 32))
+ self.assertEqual(audio.shape[0], 1)
+ self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
+
+ # fmt: off
+ expected_video_slice = torch.tensor(
+ [
+ 0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555
+ ]
+ )
+ expected_audio_slice = torch.tensor(
+ [
+ 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
+ ]
+ )
+ # fmt: on
+
+ video = video.flatten()
+ audio = audio.flatten()
+ generated_video_slice = torch.cat([video[:8], video[-8:]])
+ generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
+
+ assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
+ assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index fde3966dec97..031fdc9f9e27 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -288,31 +288,29 @@ def test_config_from_pretrained(self):
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
+ @require_bitsandbytes_version_greater("0.48.0")
def test_device_and_dtype_assignment(self):
r"""
- Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
- Checks also if other models are casted correctly.
+ Test whether trying to cast (or assign a device to) a model after converting it to 8-bit raises an error.
+ Also checks that other models are cast correctly.
"""
- with self.assertRaises(ValueError):
- # Tries with `str`
- self.model_8bit.to("cpu")
with self.assertRaises(ValueError):
# Tries with a `dtype``
self.model_8bit.to(torch.float16)
- with self.assertRaises(ValueError):
- # Tries with a `device`
- self.model_8bit.to(torch.device(f"{torch_device}:0"))
-
with self.assertRaises(ValueError):
- # Tries with a `device`
+ # Tries with a `dtype`
self.model_8bit.float()
with self.assertRaises(ValueError):
- # Tries with a `device`
+ # Tries with a `dtype`
self.model_8bit.half()
+ # These device moves should work with bitsandbytes 0.48.0
+ self.model_8bit.to("cpu")
+ self.model_8bit.to(torch.device(f"{torch_device}:0"))
+
# Test if we did not break anything
self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
input_dict_for_transformer = self.get_dummy_inputs()
@@ -837,7 +835,7 @@ def test_serialization_sharded(self):
@require_torch_version_greater_equal("2.6.0")
-@require_bitsandbytes_version_greater("0.45.5")
+@require_bitsandbytes_version_greater("0.48.0")
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
@@ -848,7 +846,7 @@ def quantization_config(self):
)
@pytest.mark.xfail(
- reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+ reason="Test fails because of a type change when recompiling."
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
)
def test_torch_compile(self):
@@ -858,6 +856,5 @@ def test_torch_compile(self):
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
- @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index e6bfc2530a5a..7a8e3cc67877 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -256,9 +256,12 @@ def test_quantization(self):
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
- ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
- ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
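+ # fp4/fp6 (fpx) reference checks only apply to torchao <= 0.14.1.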
+ if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+ QUANTIZATION_TYPES_TO_TEST.extend([
+ ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+ ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+ ])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
@@ -271,6 +274,34 @@ def test_quantization(self):
)
self._test_quant_type(quantization_config, expected_slice, model_id)
+ @unittest.skip("Skipping floatx quantization tests")
+ def test_floatx_quantization(self):
+ for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+ if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
+ if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+ quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+ self._test_quant_type(
+ quantization_config,
+ np.array(
+ [
+ 0.4648,
+ 0.5195,
+ 0.5547,
+ 0.4180,
+ 0.4434,
+ 0.6445,
+ 0.4316,
+ 0.4531,
+ 0.5625,
+ ]
+ ),
+ model_id,
+ )
+ else:
+ # Make sure the correct error is thrown
+ with self.assertRaisesRegex(ValueError, "Please downgrade"):
+ quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+
def test_int4wo_quant_bfloat16_conversion(self):
"""
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
@@ -794,8 +825,11 @@ def test_quantization(self):
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
- ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
+ if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+ QUANTIZATION_TYPES_TO_TEST.extend([
+ ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
+ ])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: