Skip to content

Conversation

leejet
Copy link
Owner

@leejet leejet commented Sep 22, 2025

txt2img

.\bin\Release\sd.exe --diffusion-model  ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors  --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf  -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
qwen_image_t2i

img2img

.\bin\Release\sd.exe --diffusion-model  ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors  --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -i ..\assets\flux\flux1-dev-q8_0.png -p "a lovely cat"
qwen_image_i2i

Qwen Image Edit

#877

@leejet leejet mentioned this pull request Sep 22, 2025
@SeanTater
Copy link

Thanks for adding this! I got it working on CPU on my machine, but as you would expect, it's quite slow.
I tried compiling with Vulkan, which compiles, but segfaults immediately as it starts the diffusion. Are you already working on that?

FWIW, Codex suggests changing ggml_vk_build_graph, which does get it to compute something - but its nonsense results. I get garbled output which doesn't appear to depend on the prompt. It's the same with or without diffusion-fa. With vae tiling, I get a floating point exception.
2025-09-23T12:54:42-04:00
When doing the VAE on the CPU instead, we have a different problem: we get tiled field like this, the exact color of which varies.
vulkan-variation-01-seed1000

I suspect maybe there is an as-yet-unimplemented op that it's basically just stubbing.

@jeffbolznv
Copy link

Where/how does it crash with Vulkan?

@wbruna
Copy link
Contributor

wbruna commented Sep 24, 2025

Where/how does it crash with Vulkan?

Testing it here, I get:

$ ./sd --diffusion-model ./qwen-image-Q4_0.gguf --vae ./Qwen_Image-VAE.safetensors --qwen2vl ./Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf -p 一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。” --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 512 -W 512 --diffusion-fa --flow-shift 3
Option: 
    n_threads:                         4
    mode:                              img_gen
    model_path:                        
    wtype:                             unspecified
    clip_l_path:                       
    clip_g_path:                       
    clip_vision_path:                  
    t5xxl_path:                        
    qwen2vl_path:                      ./Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf
    diffusion_model_path:              ./qwen-image-Q4_0.gguf
    high_noise_diffusion_model_path:   
    vae_path:                          ./Qwen_Image-VAE.safetensors
    taesd_path:                        
    esrgan_path:                       
    control_net_path:                  
    embedding_dir:                     
    photo_maker_path:                  
    pm_id_images_dir:                  
    pm_id_embed_path:                  
    pm_style_strength:                 20.00
    output_path:                       output.png
    init_image_path:                   
    end_image_path:                    
    mask_image_path:                   
    control_image_path:                
    ref_images_paths:
    control_video_path:                
    increase_ref_index:                false
    offload_params_to_cpu:             true
    clip_on_cpu:                       false
    control_net_cpu:                   false
    vae_on_cpu:                        false
    diffusion flash attention:         true
    diffusion Conv2d direct:           false
    vae_conv_direct:                   false
    control_strength:                  0.90
    prompt:                            一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”
    negative_prompt:                   
    clip_skip:                         -1
    width:                             512
    height:                            512
    sample_params:                     (txt_cfg: 2.50, img_cfg: 2.50, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: euler, sample_steps: 20, eta: 0.00, shifted_timestep: 0)
    high_noise_sample_params:          (txt_cfg: 7.00, img_cfg: 7.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: default, sample_steps: -1, eta: 0.00, shifted_timestep: 0)
    moe_boundary:                      0.875
    flow_shift:                        3.00
    strength(img2img):                 0.75
    rng:                               cuda
    seed:                              42
    batch_count:                       1
    vae_tiling:                        false
    upscale_repeats:                   1
    chroma_use_dit_mask:               true
    chroma_use_t5_mask:                false
    chroma_t5_mask_pad:                1
    video_frames:                      1
    vace_strength:                     1.00
    fps:                               16
System Info: 
    SSE3 = 1
    AVX = 1
    AVX2 = 1
    AVX512 = 0
    AVX512_VBMI = 0
    AVX512_VNNI = 0
    FMA = 1
    NEON = 0
    ARM_FMA = 0
    F16C = 1
    FP16_VA = 0
    WASM_SIMD = 0
    VSX = 0
[DEBUG] stable-diffusion.cpp:153  - Using Vulkan backend
[DEBUG] ggml_extend.hpp:62   - ggml_vulkan: Found 1 Vulkan devices:
[DEBUG] ggml_extend.hpp:62   - ggml_vulkan: 0 = AMD Radeon RX 7600 XT (RADV NAVI33) (radv) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
[INFO ] stable-diffusion.cpp:209  - loading diffusion model from './qwen-image-Q4_0.gguf'
[INFO ] model.cpp:1071 - load ./qwen-image-Q4_0.gguf using gguf format
[DEBUG] model.cpp:1088 - init from './qwen-image-Q4_0.gguf'
[INFO ] stable-diffusion.cpp:256  - loading qwen2vl from './Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf'
[INFO ] model.cpp:1071 - load ./Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf using gguf format
[DEBUG] model.cpp:1088 - init from './Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf'
[INFO ] stable-diffusion.cpp:263  - loading vae from './Qwen_Image-VAE.safetensors'
[INFO ] model.cpp:1074 - load ./Qwen_Image-VAE.safetensors using safetensors format
[DEBUG] model.cpp:1181 - init from './Qwen_Image-VAE.safetensors', prefix = 'vae.'
[INFO ] stable-diffusion.cpp:275  - Version: Qwen Image 
[INFO ] stable-diffusion.cpp:306  - Weight type:                 bf16
[INFO ] stable-diffusion.cpp:307  - Conditioner weight type:     f32
[INFO ] stable-diffusion.cpp:308  - Diffusion model weight type: bf16
[INFO ] stable-diffusion.cpp:309  - VAE weight type:             NONE
[DEBUG] stable-diffusion.cpp:311  - ggml tensor size = 400 bytes
[INFO ] stable-diffusion.cpp:350  - Using flash attention in the diffusion model
[DEBUG] qwenvl.hpp:137  - merges size 151387
[DEBUG] qwenvl.hpp:159  - vocab size: 151665
[DEBUG] ggml_extend.hpp:1738 - qwenvl2.5 params backend buffer size =  3607.26 MB(RAM) (338 tensors)
[DEBUG] ggml_extend.hpp:1738 - qwen_image params backend buffer size =  11303.54 MB(RAM) (1933 tensors)
[DEBUG] ggml_extend.hpp:1738 - wan_vae params backend buffer size =  139.84 MB(RAM) (108 tensors)
[DEBUG] stable-diffusion.cpp:583  - loading weights
[DEBUG] model.cpp:2069 - loading tensors from ./qwen-image-Q4_0.gguf
  |=======================================>          | 1933/2465 - 804.75it/s
[DEBUG] model.cpp:2069 - loading tensors from ./Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf
  |==============================================>   | 2271/2465 - 222.34it/s
[DEBUG] model.cpp:2069 - loading tensors from ./Qwen_Image-VAE.safetensors
  |==============================================>   | 2283/2465 - 223.49it/s[INFO ] model.cpp:2339 - unknown tensor 'first_stage_model.conv1.weight | bf16 | 4 [1, 1, 1, 1024, 1]' in model file
  |================================================> | 2393/2465 - 229.76it/s[INFO ] model.cpp:2339 - unknown tensor 'first_stage_model.conv1.bias | bf16 | 1 [32, 1, 1, 1, 1]' in model file
  |==================================================| 2465/2465 - 232.22it/s
[INFO ] model.cpp:2307 - loading tensors completed, taking 10.65s (process: 0.04s, read: 9.94s, memcpy: 0.00s, convert: 0.10s, copy_to_backend: 0.00s)
[INFO ] stable-diffusion.cpp:664  - total params memory size = 15050.64MB (VRAM 15050.64MB, RAM 0.00MB): text_encoders 3607.26MB(VRAM), diffusion_model 11303.55MB(VRAM), vae 139.84MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)
[INFO ] stable-diffusion.cpp:726  - running in FLOW mode
[DEBUG] stable-diffusion.cpp:750  - finished loaded file
[DEBUG] stable-diffusion.cpp:2328 - generate_image 512x512
[INFO ] stable-diffusion.cpp:2441 - TXT2IMG
init (f32): shape(64, 64, 16, 1)
[INFO ] stable-diffusion.cpp:899  - attempting to apply 0 LoRAs
[INFO ] stable-diffusion.cpp:919  - apply_loras completed, taking 0.00s
[DEBUG] stable-diffusion.cpp:920  - prompt after extract and remove lora: "一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”"
[DEBUG] conditioner.hpp:1416 - parse '<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔�
[INFO ] ggml_extend.hpp:1661 - qwenvl2.5 offload params (3607.26 MB, 338 tensors) to runtime backend (Vulkan0), taking 1.49s
[DEBUG] ggml_extend.hpp:1563 - qwenvl2.5 compute buffer size: 30.06 MB(VRAM)
Segmentation fault (core dumped)

gdb shows just this:

Thread 1 "sd" received signal SIGSEGV, Segmentation fault.
0x000055555587569d in ggml_vk_build_graph(ggml_backend_vk_context*, ggml_cgraph*, int, ggml_tensor*, int, bool, bool, bool, bool) ()

I'll try on a debug build. @jeffbolznv , anything more specific I could check?

@leejet
Copy link
Owner Author

leejet commented Sep 24, 2025

@SeanTater @wbruna This is likely because GGML Vulkan doesn’t support im2col_3d. I’ve updated GGML, so you can pull the latest code and try again.

@wbruna
Copy link
Contributor

wbruna commented Sep 24, 2025

@leejet , unfortunately a3a2b2d (with ggml 553c44706c ) crashes too:

the last output lines
ggml_backend_vk_buffer_init_tensor(0x55555963f8f0 (0x555559b0ce40), 0x7ffbc551c020)
ggml_backend_vk_buffer_init_tensor(0x55555963f8f0 (0x555559b0ce40), 0x7ffbc551c1d0)
ggml_backend_vk_buffer_set_tensor(0x55555963f8f0, 0x7ffbc548e060, 0x555559602820, 0, 4)
ggml_vk_buffer_write(4)
ggml_vk_buffer_write_2d(4, 1)
ggml_vk_create_temporary_context(0x55555a1ff900)
ggml_vk_ctx_begin(Vulkan1)
ggml_vk_create_cmd_buffer()
ggml_vk_buffer_write_2d_async(4, 1)
STAGING
ggml_vk_sync_buffers()
ggml_vk_ctx_end(0x55555a1ff900, 1)
ggml_vk_submit(0x55555a1ff900, 0x55555959c410)
ggml_vk_queue_command_pools_cleanup()
ggml_backend_vk_buffer_set_tensor(0x55555963f8f0, 0x7ffbc54a24c0, 0x7ffbe003f3a0, 0, 160)
ggml_vk_buffer_write(160)
ggml_vk_buffer_write_2d(160, 1)
ggml_vk_create_temporary_context(0x55555a1ff900)
ggml_vk_ctx_begin(Vulkan1)
ggml_vk_create_cmd_buffer()
ggml_vk_buffer_write_2d_async(160, 1)
STAGING
ggml_vk_sync_buffers()
ggml_vk_ctx_end(0x55555a1ff900, 1)
ggml_vk_submit(0x55555a1ff900, 0x55555959c410)
ggml_vk_queue_command_pools_cleanup()
ggml_vk_command_pool_cleanup()
ggml_backend_vk_buffer_set_tensor(0x55555963f8f0, 0x7ffbc54a2670, 0x55555ab88040, 0, 640)
ggml_vk_buffer_write(640)
ggml_vk_buffer_write_2d(640, 1)
ggml_vk_create_temporary_context(0x55555a1ff900)
ggml_vk_ctx_begin(Vulkan1)
ggml_vk_create_cmd_buffer()
ggml_vk_buffer_write_2d_async(640, 1)
STAGING
ggml_vk_sync_buffers()
ggml_vk_ctx_end(0x55555a1ff900, 1)
ggml_vk_submit(0x55555a1ff900, 0x55555959c410)
ggml_vk_queue_command_pools_cleanup()
ggml_backend_vk_graph_compute(1154 nodes)
ggml_vk_build_graph(0x7ffbc54a2820, RESHAPE)
ggml_vk_build_graph(0x7ffbc54a29d0, RESHAPE)
ggml_vk_build_graph(0x7ffbc54a2b80, GET_ROWS)
ggml_pipeline_request_descriptor_sets(
Thread 1 "sd" received signal SIGSEGV, Segmentation fault.
0x00007ffff7d54b24 in std::basic_ostream<char, std::char_traits<char> >& std::operator<< <char, std::char_traits<char>, std::allocator<char> >(std::basic_ostream<char, std::char_traits<char> >&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) ()
   from /lib/x86_64-linux-gnu/libstdc++.so.6
(gdb)
GDB backtrace
(gdb) bt
#0  0x00007ffff7d54b24 in std::basic_ostream<char, std::char_traits<char> >& std::operator<< <char, std::char_traits<char>, std::allocator<char> >(std::basic_ostream<char, std::char_traits<char> >&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) ()
   from /lib/x86_64-linux-gnu/libstdc++.so.6
#1  0x00005555558c25b2 in ggml_pipeline_request_descriptor_sets (ctx=0x5555594e62e0, pipeline=std::shared_ptr<vk_pipeline_struct> (empty) = {...}, n=1)
    at ./ggml/src/ggml-vulkan/ggml-vulkan.cpp:1653
#2  0x00005555559a38f3 in ggml_vk_build_graph (ctx=0x5555594e62e0, cgraph=0x7ffbc548e210, node_idx=2, node_begin=0x0, node_idx_begin=0, dryrun=true, 
    last_node=false, almost_ready=false, submit=false) at ./ggml/src/ggml-vulkan/ggml-vulkan.cpp:10650
#3  0x00005555559a9e1e in ggml_backend_vk_graph_compute (backend=0x555559547bb0, cgraph=0x7ffbc548e210)
    at ./ggml/src/ggml-vulkan/ggml-vulkan.cpp:11743
#4  0x0000555555a79755 in ggml_backend_graph_compute_async (backend=0x555559547bb0, cgraph=0x7ffbc548e210)
    at ./ggml/src/ggml-backend.cpp:359
#5  0x0000555555a796e5 in ggml_backend_graph_compute (backend=0x555559547bb0, cgraph=0x7ffbc548e210)
    at ./ggml/src/ggml-backend.cpp:352
#6  0x0000555555682db2 in GGMLRunner::compute(std::function<ggml_cgraph* ()>, int, bool, ggml_tensor**, ggml_context*) (this=0x555559d6b550, 
    get_graph=..., n_threads=1, free_compute_buffer_immediately=true, output=0x7fffffffb3a8, output_ctx=0x55555942bf60)
    at ./ggml_extend.hpp:1824
#7  0x0000555555694938 in Qwen::Qwen2_5_VLRunner::compute (this=0x555559d6b550, n_threads=1, input_ids=0x7ffbe003f210, output=0x7fffffffb3a8, 
    output_ctx=0x55555942bf60) at ./qwenvl.hpp:603
#8  0x00005555556a7e18 in Qwen2_5_VLCLIPEmbedder::get_learned_condition_common (this=0x55555942b9b0, work_ctx=0x55555942bf60, n_threads=1, 
    token_and_weights=std::tuple containing = {...}, clip_skip=-1, zero_out_masked=false) at ./conditioner.hpp:1452
#9  0x00005555556a8287 in Qwen2_5_VLCLIPEmbedder::get_learned_condition (this=0x55555942b9b0, work_ctx=0x55555942bf60, n_threads=1, text="flower", 
    clip_skip=-1, width=512, height=512, adm_in_channels=768, zero_out_masked=false) at ./conditioner.hpp:1500
#10 0x000055555565f7d0 in generate_image_internal (sd_ctx=0x55555942a480, work_ctx=0x55555942bf60, init_latent=0x7ffbdffff060, prompt="flower", 
    negative_prompt="", clip_skip=-1, guidance=..., eta=0, shifted_timestep=0, width=512, height=512, sample_method=EULER, 
    sigmas=std::vector of length 21, capacity 32 = {...}, seed=42, batch_count=1, control_image=..., control_strength=0.899999976, pm_params=..., 
    ref_latents=std::vector of length 0, capacity 0, increase_ref_index=false, concat_latent=0x0, denoise_mask=0x0)
    at ./stable-diffusion.cpp:2086
#11 0x0000555555661c6b in generate_image (sd_ctx=0x55555942a480, sd_img_gen_params=0x7fffffffbe70)
    at ./stable-diffusion.cpp:2492
#12 0x00005555555add61 in main (argc=25, argv=0x7fffffffd828) at ./examples/cli/main.cpp:1392

@jeffbolznv
Copy link

What are the src and dst types for the GET_ROWS that crashes?

@wbruna
Copy link
Contributor

wbruna commented Sep 24, 2025

(gdb) frame 2
#2  0x00005555559a38f3 in ggml_vk_build_graph (ctx=0x5555594e62e0, cgraph=0x7ffbc548e210, node_idx=2, node_begin=0x0, node_idx_begin=0, dryrun=true, 
    last_node=false, almost_ready=false, submit=false) at ./ggml/src/ggml-vulkan/ggml-vulkan.cpp:10650
10650                   ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
(gdb) print src0->type
$6 = GGML_TYPE_Q4_K
(gdb) print src1->type
$7 = GGML_TYPE_I32
(gdb) print src2->type
Cannot access memory at address 0x0
(gdb) print node->op
$8 = GGML_OP_GET_ROWS

Interesting... the model files are qwen-image-Q4_0.gguf and Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf .

@jeffbolznv
Copy link

Thanks. We're missing the K quants but I don't think there's any reason for this. I'll add it.

@jeffbolznv
Copy link

Please try ggml-org/llama.cpp#16235.

@SeanTater
Copy link

SeanTater commented Sep 25, 2025

After applying the change from Jeff's PR in llama.cpp to the ggml submodule in stable-diffusion.cpp, it does run, no crash. But I get garbed output, and even though it does recognize the devices:

[DEBUG] stable-diffusion.cpp:153  - Using Vulkan backend
[DEBUG] ggml_extend.hpp:62   - ggml_vulkan: Found 2 Vulkan devices:
[DEBUG] ggml_extend.hpp:62   - ggml_vulkan: 0 = AMD Radeon RX 7900 XTX (RADV NAVI31) (radv) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 0 | matrix cores: KHR_coopmat
[DEBUG] ggml_extend.hpp:62   - ggml_vulkan: 1 = Intel(R) Graphics (RPL-S) (Intel open-source Mesa driver) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 0 | matrix cores: none

.. and it swears it places them on VRAM ..

[INFO ] stable-diffusion.cpp:664  - total params memory size = 16634.26MB (VRAM 16634.26MB, RAM 0.00MB): text_encoders 4034.09MB(VRAM), diffusion_model 12460.33MB(VRAM), vae 139.84MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)

rocm-smi disagrees:

Device  Node  IDs              Temp    Power  Partitions          SCLK  MCLK   Fan  Perf  PwrCap  VRAM%  GPU%  
              (DID,     GUID)  (Edge)  (Avg)  (Mem, Compute, ID)                                               
================================================================================================================
0       1     0x744c,   57282  37.0°C  6.0W   N/A, N/A, 0         0Mhz  96Mhz  0%   auto  327.0W  11%    0%    

It does finish in 17 seconds per step as opposed to about 70 for successful CPU sampling, but I think that may be a red herring since the output is garbage and the GPU is idle

@wbruna
Copy link
Contributor

wbruna commented Sep 25, 2025

After applying ggml-org/llama.cpp@9073a73 and ggml-org/llama.cpp#16235 , I got a broken image too:

./sd --diffusion-model ./Qwen_Image_Distill-Q4_0.gguf --vae ./Qwen_Image-VAE-f16.gguf --qwen2vl ./Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf -p '(...)' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 512 -W 512 --diffusion-fa --steps 20

testqwen01

VAE tiling also crashes with a Floating point exception(core dumped) .

@jeffbolznv
Copy link

I'm seeing similar corruption. I'll try to debug it.

@leejet
Copy link
Owner Author

leejet commented Sep 25, 2025

I updated ggml to the latest commit and optimized the handling of embedding weights, so there’s no need to use k_quant’s get_rows. I’m not sure if this will fix the Vulkan issue.

@jeffbolznv
Copy link

I don't think it's related to get_rows. Setting GGML_VK_DISABLE_FUSION=1 seems to fix it. I'll continue to narrow it down.

@jeffbolznv
Copy link

Oops, I think I mixed up my experiments. I think it's forcing GGML_PREC_F32 for matrix-matrix multiplies that's fixing it. I don't know which multiplies, I just forced it for all of them.

@leejet
Copy link
Owner Author

leejet commented Oct 10, 2025

@wbruna Could you show me your detailed parameters for comparison?

@wbruna
Copy link
Contributor

wbruna commented Oct 10, 2025

@leejet , testing b769da2 with the Lightning distill:

[DEBUG] ggml_extend.hpp:62 - ggml_vulkan: Found 1 Vulkan devices:
[DEBUG] ggml_extend.hpp:62 - ggml_vulkan: 0 = AMD Radeon RX 7600 XT (RADV NAVI33) (radv) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat

sd --diffusion-model qwen-image-lighting-8steps-V1.0-Q4_K_S.gguf --qwen2vl Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf -W 320 -H 320 --vae Qwen_Image-VAE.safetensors --sampling-method euler --steps 8 --cfg-scale 1 -p "a cartoon flower" --diffusion-fa --offload-to-cpu -o ./teste_1760121748_1.png

teste_1760121748_1

The same command with ROCm, just replacing the binary, renders a black image. The ROCm build command is:

cmake -B sd_hipblas ./sd.cpp -DSD_BUILD_SHARED_LIBS=ON -DSD_HIPBLAS=ON -DGGML_HIP_ROCWMMA_FATTN=ON -DGPU_TARGETS=gfx1102

full ROCM output (Vulkan is identical except for the hardware info and the timings)
Option: 
    n_threads:                         4
    mode:                              img_gen
    model_path:                        
    wtype:                             unspecified
    clip_l_path:                       
    clip_g_path:                       
    clip_vision_path:                  
    t5xxl_path:                        
    qwen2vl_path:                      Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf
    qwen2vl_vision_path:               
    diffusion_model_path:              qwen-image-lighting-8steps-V1.0-Q4_K_S.gguf
    high_noise_diffusion_model_path:   
    vae_path:                          Qwen_Image-VAE.safetensors
    taesd_path:                        
    esrgan_path:                       
    control_net_path:                  
    embedding_dir:                     
    photo_maker_path:                  
    pm_id_images_dir:                  
    pm_id_embed_path:                  
    pm_style_strength:                 20.00
    output_path:                       ./teste_1760121748_2.png
    init_image_path:                   
    end_image_path:                    
    mask_image_path:                   
    control_image_path:                
    ref_images_paths:
    control_video_path:                
    increase_ref_index:                false
    offload_params_to_cpu:             true
    clip_on_cpu:                       false
    control_net_cpu:                   false
    vae_on_cpu:                        false
    diffusion flash attention:         true
    diffusion Conv2d direct:           false
    vae_conv_direct:                   false
    control_strength:                  0.90
    prompt:                            a cartoon flower
    negative_prompt:                   
    clip_skip:                         -1
    width:                             320
    height:                            320
    sample_params:                     (txt_cfg: 1.00, img_cfg: 1.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: euler, sample_steps: 8, eta: 0.00, shifted_timestep: 0)
    high_noise_sample_params:          (txt_cfg: 7.00, img_cfg: 7.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: default, sample_steps: -1, eta: 0.00, shifted_timestep: 0)
    moe_boundary:                      0.875
    flow_shift:                        inf
    strength(img2img):                 0.75
    rng:                               cuda
    seed:                              42
    batch_count:                       1
    vae_tiling:                        false
    upscale_repeats:                   1
    chroma_use_dit_mask:               true
    chroma_use_t5_mask:                false
    chroma_t5_mask_pad:                1
    video_frames:                      1
    vace_strength:                     1.00
    fps:                               16
System Info: 
    SSE3 = 1
    AVX = 1
    AVX2 = 1
    AVX512 = 0
    AVX512_VBMI = 0
    AVX512_VNNI = 0
    FMA = 1
    NEON = 0
    ARM_FMA = 0
    F16C = 1
    FP16_VA = 0
    WASM_SIMD = 0
    VSX = 0
[DEBUG] stable-diffusion.cpp:147  - Using CUDA backend
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: found 2 ROCm devices:
[INFO ] ggml_extend.hpp:65   -   Device 0: AMD Radeon RX 7600 XT, gfx1102 (0x1102), VMM: no, Wave Size: 32
[INFO ] ggml_extend.hpp:65   -   Device 1: AMD Radeon Vega 11 Graphics, gfx902:xnack- (0x902), VMM: no, Wave Size: 64
[INFO ] stable-diffusion.cpp:211  - loading diffusion model from 'qwen-image-lighting-8steps-V1.0-Q4_K_S.gguf'
[INFO ] model.cpp:1098 - load qwen-image-lighting-8steps-V1.0-Q4_K_S.gguf using gguf format
[DEBUG] model.cpp:1115 - init from 'qwen-image-lighting-8steps-V1.0-Q4_K_S.gguf'
[INFO ] stable-diffusion.cpp:258  - loading qwen2vl from 'Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf'
[INFO ] model.cpp:1098 - load Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf using gguf format
[DEBUG] model.cpp:1115 - init from 'Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf'
[INFO ] stable-diffusion.cpp:272  - loading vae from 'Qwen_Image-VAE.safetensors'
[INFO ] model.cpp:1101 - load Qwen_Image-VAE.safetensors using safetensors format
[DEBUG] model.cpp:1208 - init from 'Qwen_Image-VAE.safetensors', prefix = 'vae.'
[INFO ] stable-diffusion.cpp:293  - Version: Qwen Image 
[INFO ] stable-diffusion.cpp:324  - Weight type:                 q4_K
[INFO ] stable-diffusion.cpp:325  - Conditioner weight type:     f32
[INFO ] stable-diffusion.cpp:326  - Diffusion model weight type: q4_K
[INFO ] stable-diffusion.cpp:327  - VAE weight type:             NONE
[DEBUG] stable-diffusion.cpp:329  - ggml tensor size = 400 bytes
[INFO ] stable-diffusion.cpp:373  - Using flash attention in the diffusion model
[DEBUG] qwenvl.hpp:139  - merges size 151387
[DEBUG] qwenvl.hpp:161  - vocab size: 151665
[ERROR] qwen_image.hpp:534  - qwen_image_params.num_layers: 60
[DEBUG] ggml_extend.hpp:1745 - qwenvl2.5 params backend buffer size =  5393.90 MB(RAM) (338 tensors)
[DEBUG] ggml_extend.hpp:1745 - qwen_image params backend buffer size =  10978.29 MB(RAM) (1933 tensors)
[DEBUG] ggml_extend.hpp:1745 - wan_vae params backend buffer size =  139.84 MB(RAM) (108 tensors)
[DEBUG] stable-diffusion.cpp:612  - loading weights
[DEBUG] model.cpp:2031 - using 4 threads for model loading
[DEBUG] model.cpp:2114 - loading tensors from qwen-image-lighting-8steps-V1.0-Q4_K_S.gguf
  |=======================================>          | 1933/2465 - 95.32it/s
[DEBUG] model.cpp:2114 - loading tensors from Qwen2.5-VL-7B-Instruct-IQ4_XS.gguf
  |==============================================>   | 2271/2465 - 85.76it/s
[DEBUG] model.cpp:2114 - loading tensors from Qwen_Image-VAE.safetensors
  |==================================================| 2465/2465 - 91.70it/s
[INFO ] model.cpp:2352 - loading tensors completed, taking 26.93s (process: 0.04s, read: 26.24s, memcpy: 0.00s, convert: 0.44s, copy_to_backend: 0.00s)
[INFO ] stable-diffusion.cpp:695  - total params memory size = 16512.03MB (VRAM 16512.03MB, RAM 0.00MB): text_encoders 5393.90MB(VRAM), diffusion_model 10978.29MB(VRAM), vae 139.84MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)
[INFO ] stable-diffusion.cpp:757  - running in FLOW mode
[DEBUG] stable-diffusion.cpp:781  - finished loaded file
[DEBUG] stable-diffusion.cpp:2406 - generate_image 320x320
[INFO ] stable-diffusion.cpp:2533 - TXT2IMG
[INFO ] stable-diffusion.cpp:930  - attempting to apply 0 LoRAs
[INFO ] stable-diffusion.cpp:950  - apply_loras completed, taking 0.00s
[DEBUG] stable-diffusion.cpp:951  - prompt after extract and remove lora: "a cartoon flower"
[DEBUG] conditioner.hpp:1433 - parse '<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
a cartoon flower<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
a cartoon flower<|im_end|>
<|im_start|>assistant
', 1], ]
[INFO ] ggml_extend.hpp:1668 - qwenvl2.5 offload params (5393.90 MB, 338 tensors) to runtime backend (ROCm0), taking 1.57s
[DEBUG] ggml_extend.hpp:1570 - qwenvl2.5 compute buffer size: 7.79 MB(VRAM)
[DEBUG] conditioner.hpp:1572 - computing condition graph completed, taking 1773 ms
[INFO ] stable-diffusion.cpp:2144 - get_learned_condition completed, taking 1776 ms
[INFO ] stable-diffusion.cpp:2169 - sampling using Euler method
[INFO ] stable-diffusion.cpp:2263 - generating image: 1/1 - seed 42
[INFO ] ggml_extend.hpp:1668 - qwen_image offload params (10978.29 MB, 1933 tensors) to runtime backend (ROCm0), taking 2.15s
[DEBUG] ggml_extend.hpp:1570 - qwen_image compute buffer size: 38.53 MB(VRAM)
  |==================================================| 8/8 - 1.73s/it
[INFO ] stable-diffusion.cpp:2300 - sampling completed, taking 19.23s
[INFO ] stable-diffusion.cpp:2308 - generating 1 latent images completed, taking 20.61s
[INFO ] stable-diffusion.cpp:2311 - decoding 1 latents
[INFO ] ggml_extend.hpp:1668 - wan_vae offload params (139.84 MB, 108 tensors) to runtime backend (ROCm0), taking 0.09s
[DEBUG] ggml_extend.hpp:1570 - wan_vae compute buffer size: 732.76 MB(VRAM)
[DEBUG] stable-diffusion.cpp:1636 - computing vae decode graph completed, taking 0.57s
[INFO ] stable-diffusion.cpp:2321 - latent 1 decoded, taking 0.57s
[INFO ] stable-diffusion.cpp:2325 - decode_first_stage completed, taking 0.57s
[INFO ] stable-diffusion.cpp:2639 - generate_image completed in 22.98s
save result PNG image to './teste_1760121748_2.png'

@LostRuins
Copy link
Contributor

LostRuins commented Oct 11, 2025

Hello @leejet , sorry for the delay
I am using RTX4090 on Windows. My build was built by github CI, you can see the exact commit from my stable-diffusion.cpp fork here
The only difference is changing the CMAKE_CUDA_ARCHITECTURES to work with my RTX4090.

My params are very simple and now fully reproducible every time.
sd.exe --diffusion-model qwen-image-Q4_K_S.gguf --qwen2vl Qwen2.5-VL-7B-Instruct-Q4_K_M.gguf --vae qwen_image_vae.safetensors -p cat

Here is the terminal log:

D:\sdcpp>sd.exe --diffusion-model qwen-image-Q4_K_S.gguf --qwen2vl Qwen2.5-VL-7B-Instruct-Q4_K_M.gguf --vae qwen_image_vae.safetensors -p cat
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: found 1 CUDA devices:
[INFO ] ggml_extend.hpp:65   -   Device 0: NVIDIA GeForce RTX 4090 Laptop GPU, compute capability 8.9, VMM: yes
[INFO ] stable-diffusion.cpp:209  - loading diffusion model from 'qwen-image-Q4_K_S.gguf'
[INFO ] model.cpp:1072 - load qwen-image-Q4_K_S.gguf using gguf format
[INFO ] stable-diffusion.cpp:256  - loading qwen2vl from 'Qwen2.5-VL-7B-Instruct-Q4_K_M.gguf'
[INFO ] model.cpp:1072 - load Qwen2.5-VL-7B-Instruct-Q4_K_M.gguf using gguf format
[INFO ] stable-diffusion.cpp:263  - loading vae from 'qwen_image_vae.safetensors'
[INFO ] model.cpp:1075 - load qwen_image_vae.safetensors using safetensors format
[INFO ] stable-diffusion.cpp:275  - Version: Qwen Image
[INFO ] stable-diffusion.cpp:306  - Weight type:                 bf16
[INFO ] stable-diffusion.cpp:307  - Conditioner weight type:     f32
[INFO ] stable-diffusion.cpp:308  - Diffusion model weight type: bf16
[INFO ] stable-diffusion.cpp:309  - VAE weight type:             NONE
  |=======================================>          | 1933/2465 - 590.77it/s←[K
  |==============================================>   | 2271/2465 - 442.35it/s←[K
  |==============================================>   | 2271/2465 - 442.26it/s←[K[INFO ] model.cpp:2353 - unknown tensor 'first_stage_model.conv1.weight | bf16 | 4 [1, 1, 1, 1024, 1]' in model file
[INFO ] model.cpp:2353 - unknown tensor 'first_stage_model.conv1.weight | bf16 | 4 [1, 1, 1, 1024, 1]' in model file
  |==================================================| 2465/2465 - 461.26it/s←[K
[INFO ] model.cpp:2327 - loading tensors completed, taking 5.35s (process: 0.01s, read: 3.31s, memcpy: 0.00s, convert: 0.06s, copy_to_backend: 0.93s)
[INFO ] stable-diffusion.cpp:679  - total params memory size = 17538.62MB (VRAM 17538.62MB, RAM 0.00MB): text_encoders 5820.73MB(VRAM), diffusion_model 11578.05MB(VRAM), vae 139.84MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)
[INFO ] stable-diffusion.cpp:726  - running in FLOW mode
[INFO ] stable-diffusion.cpp:2441 - TXT2IMG
init (f32): shape(64, 64, 16, 1)
[INFO ] stable-diffusion.cpp:899  - attempting to apply 0 LoRAs
[INFO ] stable-diffusion.cpp:919  - apply_loras completed, taking 0.00s
[INFO ] stable-diffusion.cpp:2111 - get_learned_condition completed, taking 150 ms
[INFO ] stable-diffusion.cpp:2136 - sampling using Euler method
[INFO ] stable-diffusion.cpp:2185 - generating image: 1/1 - seed 42
  |==================================================| 20/20 - 1.34it/s←[K
[INFO ] stable-diffusion.cpp:2222 - sampling completed, taking 16.16s
[INFO ] stable-diffusion.cpp:2230 - generating 1 latent images completed, taking 16.39s
[INFO ] stable-diffusion.cpp:2233 - decoding 1 latents
[INFO ] stable-diffusion.cpp:2243 - latent 1 decoded, taking 0.41s
[INFO ] stable-diffusion.cpp:2247 - decode_first_stage completed, taking 0.41s
[INFO ] stable-diffusion.cpp:2517 - generate_image completed in 16.96s
save result PNG image to 'output.png'

And here is the output, which is a black square:
output

Only cuda is broken. CPU and Vulkan work fine.

@phil2sat
Copy link

phil2sat commented Oct 11, 2025

with Rocm i get also black image BUT the first step does an image. so its something with sampling i guess. maybe on cuda also since Rocm is similar to cuda or better part from cuda.

So maybe try to generate an 1-step image with cuda and see if its not black

@leejet
Copy link
Owner Author

leejet commented Oct 11, 2025

I’m also using a 4090. When I use CUDA, image generation works normally, but I’m using Q8 quantization.

Option: 
    n_threads:                         8
    mode:                              img_gen
    model_path:
    wtype:                             unspecified
    clip_l_path:
    clip_g_path:
    clip_vision_path:
    t5xxl_path:
    qwen2vl_path:                      ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf
    diffusion_model_path:              ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf
    high_noise_diffusion_model_path:
    vae_path:                          ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors
    taesd_path:
    esrgan_path:
    control_net_path:
    embedding_dir:
    photo_maker_path:
    pm_id_images_dir:
    pm_id_embed_path:
    pm_style_strength:                 20.00
    output_path:                       output.png
    init_image_path:
    end_image_path:
    mask_image_path:
    control_image_path:
    ref_images_paths:
    control_video_path:
    increase_ref_index:                false
    offload_params_to_cpu:             true
    clip_on_cpu:                       false
    control_net_cpu:                   false
    vae_on_cpu:                        false
    diffusion flash attention:         false
    diffusion Conv2d direct:           false
    vae_conv_direct:                   false
    control_strength:                  0.90
    prompt:                            cat
    negative_prompt:
    clip_skip:                         -1
    width:                             512
    height:                            512
    sample_params:                     (txt_cfg: 7.00, img_cfg: 7.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: default, sample_steps: 20, eta: 0.00, shifted_timestep: 0)
    high_noise_sample_params:          (txt_cfg: 7.00, img_cfg: 7.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: default, sample_steps: -1, eta: 0.00, shifted_timestep: 0)
    moe_boundary:                      0.875
    flow_shift:                        inf
    strength(img2img):                 0.75
    rng:                               cuda
    seed:                              42
    batch_count:                       1
    vae_tiling:                        false
    upscale_repeats:                   1
    chroma_use_dit_mask:               true
    chroma_use_t5_mask:                false
    chroma_t5_mask_pad:                1
    video_frames:                      1
    vace_strength:                     1.00
    fps:                               16
System Info:
    SSE3 = 1
    AVX = 1
    AVX2 = 1
    AVX512 = 0
    AVX512_VBMI = 0
    AVX512_VNNI = 0
    FMA = 1
    NEON = 0
    ARM_FMA = 0
    F16C = 1
    FP16_VA = 0
    WASM_SIMD = 0
    VSX = 0
[DEBUG] stable-diffusion.cpp:147  - Using CUDA backend
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
[INFO ] ggml_extend.hpp:65   - ggml_cuda_init: found 1 CUDA devices:
[INFO ] ggml_extend.hpp:65   -   Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
[INFO ] stable-diffusion.cpp:211  - loading diffusion model from '..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf'
[INFO ] model.cpp:1072 - load ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf using gguf format
[DEBUG] model.cpp:1089 - init from '..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf'
[INFO ] stable-diffusion.cpp:258  - loading qwen2vl from '..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf'
[INFO ] model.cpp:1072 - load ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf using gguf format
[DEBUG] model.cpp:1089 - init from '..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf'
[INFO ] stable-diffusion.cpp:265  - loading vae from '..\..\ComfyUI\models\vae\qwen_image_vae.safetensors'
[INFO ] model.cpp:1075 - load ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors using safetensors format
[DEBUG] model.cpp:1182 - init from '..\..\ComfyUI\models\vae\qwen_image_vae.safetensors', prefix = 'vae.'
[INFO ] stable-diffusion.cpp:277  - Version: Qwen Image 
[INFO ] stable-diffusion.cpp:308  - Weight type:                 bf16
[INFO ] stable-diffusion.cpp:309  - Conditioner weight type:     f32
[INFO ] stable-diffusion.cpp:310  - Diffusion model weight type: bf16
[INFO ] stable-diffusion.cpp:311  - VAE weight type:             NONE
[DEBUG] stable-diffusion.cpp:313  - ggml tensor size = 400 bytes
[DEBUG] qwenvl.hpp:137  - merges size 151387
[DEBUG] qwenvl.hpp:159  - vocab size: 151665
[DEBUG] ggml_extend.hpp:1752 - qwenvl2.5 params backend buffer size =  7165.44 MB(RAM) (338 tensors)
[DEBUG] ggml_extend.hpp:1752 - qwen_image params backend buffer size =  20753.54 MB(RAM) (1933 tensors)
[DEBUG] ggml_extend.hpp:1752 - wan_vae params backend buffer size =  139.84 MB(RAM) (108 tensors)
[DEBUG] stable-diffusion.cpp:590  - loading weights
[DEBUG] model.cpp:2005 - using 8 threads for model loading
[DEBUG] model.cpp:2088 - loading tensors from ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf
  |=======================================>          | 1933/2465 - 241.29it/s
[DEBUG] model.cpp:2088 - loading tensors from ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf
  |==============================================>   | 2271/2465 - 210.34it/s
[DEBUG] model.cpp:2088 - loading tensors from ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors
  |==============================================>   | 2271/2465 - 210.34it/s[INFO ] model.cpp:2358 - unknown tensor 'first_stage_model.conv1.weight | bf16 | 4 [1, 1, 1, 1024, 1]' in model file
[INFO ] model.cpp:2358 - unknown tensor 'first_stage_model.conv1.weight | bf16 | 4 [1, 1, 1, 1024, 1]' in model file
  |==================================================| 2465/2465 - 224.09it/s
[INFO ] model.cpp:2332 - loading tensors completed, taking 11.02s (process: 0.02s, read: 10.64s, memcpy: 0.00s, convert: 0.05s, copy_to_backend: 0.00s)
[INFO ] stable-diffusion.cpp:686  - total params memory size = 28058.83MB (VRAM 28058.83MB, RAM 0.00MB): text_encoders 7165.44MB(VRAM), diffusion_model 20753.55MB(VRAM), vae 139.84MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)
[INFO ] stable-diffusion.cpp:733  - running in FLOW mode
[DEBUG] stable-diffusion.cpp:757  - finished loaded file
[DEBUG] stable-diffusion.cpp:2380 - generate_image 512x512
[INFO ] stable-diffusion.cpp:2507 - TXT2IMG
init (f32): shape(64, 64, 16, 1)
[INFO ] stable-diffusion.cpp:906  - attempting to apply 0 LoRAs
[INFO ] stable-diffusion.cpp:926  - apply_loras completed, taking 0.00s
[DEBUG] stable-diffusion.cpp:927  - prompt after extract and remove lora: "cat"
[DEBUG] conditioner.hpp:1416 - parse '<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
cat<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
cat<|im_end|>
<|im_start|>assistant
', 1], ]
[INFO ] ggml_extend.hpp:1676 - qwenvl2.5 offload params (7165.44 MB, 338 tensors) to runtime backend (CUDA0), taking 1.31s
[DEBUG] ggml_extend.hpp:1576 - qwenvl2.5 compute buffer size: 7.42 MB(VRAM)
[DEBUG] conditioner.hpp:1486 - computing condition graph completed, taking 1541 ms
[DEBUG] conditioner.hpp:1416 - parse '<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
<|im_end|>
<|im_start|>assistant
', 1], ]
[INFO ] ggml_extend.hpp:1676 - qwenvl2.5 offload params (7165.44 MB, 338 tensors) to runtime backend (CUDA0), taking 1.16s
[DEBUG] ggml_extend.hpp:1576 - qwenvl2.5 compute buffer size: 7.24 MB(VRAM)
[DEBUG] conditioner.hpp:1486 - computing condition graph completed, taking 1236 ms
[INFO ] stable-diffusion.cpp:2118 - get_learned_condition completed, taking 2782 ms
[INFO ] stable-diffusion.cpp:2143 - sampling using Euler method
[INFO ] stable-diffusion.cpp:2237 - generating image: 1/1 - seed 42
[INFO ] ggml_extend.hpp:1676 - qwen_image offload params (20753.54 MB, 1933 tensors) to runtime backend (CUDA0), taking 3.40s
[DEBUG] ggml_extend.hpp:1576 - qwen_image compute buffer size: 183.00 MB(VRAM)
  |==================================================| 20/20 - 2.50it/s
[INFO ] stable-diffusion.cpp:2274 - sampling completed, taking 11.80s
[INFO ] stable-diffusion.cpp:2282 - generating 1 latent images completed, taking 13.75s
[INFO ] stable-diffusion.cpp:2285 - decoding 1 latents
[INFO ] ggml_extend.hpp:1676 - wan_vae offload params (139.84 MB, 108 tensors) to runtime backend (CUDA0), taking 0.03s
[DEBUG] ggml_extend.hpp:1576 - wan_vae compute buffer size: 1874.50 MB(VRAM)
[DEBUG] stable-diffusion.cpp:1612 - computing vae decode graph completed, taking 0.31s
[INFO ] stable-diffusion.cpp:2295 - latent 1 decoded, taking 0.31s
[INFO ] stable-diffusion.cpp:2299 - decode_first_stage completed, taking 0.31s
[INFO ] stable-diffusion.cpp:2583 - generate_image completed in 16.85s
save result PNG image to 'output.png'
output

@leejet
Copy link
Owner Author

leejet commented Oct 11, 2025

I tried the q4_K_S quantization and reproduced the black image issue. It’s likely due to a precision problem in the CUDA computations related to q4_K_S.

@wbruna
Copy link
Contributor

wbruna commented Oct 11, 2025

The full q8_0 model won't fit on my card, but the Pruning 13b q8_0 works on ROCm, too.

@LostRuins
Copy link
Contributor

LostRuins commented Oct 11, 2025

i wonder what ops it could be besides mat mul, since it works on vulkan with q4ks

@jeffbolznv
Copy link

I had to implement get_rows for this, I think cuda may also be missing this.

BTW, its a shame that sd is still using the ggml path that doesn't automatically fallback on unsupported ops.

@leejet
Copy link
Owner Author

leejet commented Oct 11, 2025

In sd.cpp, if a type isn’t one of {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}, it’ll first get converted to GGML_TYPE_F32 on the CPU before calling get_rows.

@LostRuins
Copy link
Contributor

LostRuins commented Oct 12, 2025

@leejet @jeffbolznv I tried using q4_0 instead of q4_k_s and it worked on cuda!!!
very strange very very strange.
but at least its a clue

@leejet
Copy link
Owner Author

leejet commented Oct 12, 2025

@leejet @jeffbolznv I tried using q4_0 instead of q4_k_s and it worked on cuda!!! very strange very very strange. but at least its a clue

I tried using q2_k as well, and it also generated a black image. Maybe all the k-quants have this issue.

@LostRuins
Copy link
Contributor

LostRuins commented Oct 12, 2025

If we can narrow it down to a specific offending tensor, maybe we can force convert that to a different type.

edit2: it fails when i load a "q4_0" model, but it works when i manually set the wtype to convert to q4.

@leejet
Copy link
Owner Author

leejet commented Oct 12, 2025

I’ve located and fixed the issue. It’s working fine on my side now — you can test it again on your end @LostRuins @wbruna.

q2_k with cuda

.\bin\Release\sd.exe --diffusion-model  ..\..\ComfyUI\models\diffusion_models\qwen-image-Q2_K.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors  --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf  -p "cat" -v --offload-to-cpu
output

common.hpp Outdated
Comment on lines 267 to 273
// The purpose of the scale here is to prevent NaN issues in certain situations.
// For example, when using Vulkan without enabling force_prec_f32,
// or when using CUDA but the weights are k-quants.
float scale = 1.f / 128.f;
x = ggml_scale(ctx, x, scale);
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
x = ggml_scale(ctx, x, 1.f / scale);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious which part in the CUDA backend causes the issue here? I assume you are working around some FP overflow?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s likely that ggml_mul_mat has a precision issue when the weights are k-quants.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why did Jeff's ggml_mul_mat_set_prec fix work for vulkan but not cuda, could cuda be ignoring that?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cuda approach to matmul is pretty different (see #851 (comment)). Anecdotally it seems to be less prone to precision issues, but I guess it can still run into problems.

@LostRuins
Copy link
Contributor

Alright i'm building all your latest changes on cuda and will let you know how it goes

@wbruna
Copy link
Contributor

wbruna commented Oct 12, 2025

@leejet , tested d21d1aa : working on ROCm for Q4_K_M, too!

(something unrelated seemingly broke the Pruning support, but that can be discussed on #874 )- (edit: working as intended on the qwen_image_edit branch)

@leejet
Copy link
Owner Author

leejet commented Oct 12, 2025

Alright i'm building all your latest changes on cuda and will let you know how it goes

@LostRuins Are there still any issues on your side during testing?

@LostRuins
Copy link
Contributor

My previous build failed due to an unrelated issue, so I am rebuilding it again. It takes me 1 hour for each CUDA build lol.

@leejet
Copy link
Owner Author

leejet commented Oct 12, 2025

Yeah… building ggml with CUDA always takes a lot of time.

@LostRuins
Copy link
Contributor

@leejet seems to be working well now, thanks!

image

@leejet
Copy link
Owner Author

leejet commented Oct 12, 2025

It looks like this PR can be merged now. Thanks, everyone!

@leejet leejet merged commit beb99a2 into master Oct 12, 2025
8 checks passed
@phil2sat
Copy link

phil2sat commented Oct 13, 2025

Can confirm all fine for now, even with clip on gpu and on my old Mi25:
output
Nice work... Thanks

Details

build/bin/sd --diffusion-model /daten/models/unet/Qwen_Image_Distill-Q4_0.gguf --vae /daten/models/vae/qwen_image_vae.safetensors --qwen2vl /daten/models/text_encoders/Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p 'a comic hero holding a sign "Qwen-Image now also works on Rocm"' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 512 -W 512 --flow-shift 3 -s 42 --steps 10 --offload-to-cpu
Option:
n_threads: 4
mode: img_gen
model_path:
wtype: unspecified
clip_l_path:
clip_g_path:
clip_vision_path:
t5xxl_path:
qwen2vl_path: /daten/models/text_encoders/Qwen2.5-VL-7B-Instruct-Q8_0.gguf
diffusion_model_path: /daten/models/unet/Qwen_Image_Distill-Q4_0.gguf
high_noise_diffusion_model_path:
vae_path: /daten/models/vae/qwen_image_vae.safetensors
taesd_path:
esrgan_path:
control_net_path:
embedding_dir:
photo_maker_path:
pm_id_images_dir:
pm_id_embed_path:
pm_style_strength: 20.00
output_path: output.png
init_image_path:
end_image_path:
mask_image_path:
control_image_path:
ref_images_paths:
control_video_path:
increase_ref_index: false
offload_params_to_cpu: true
clip_on_cpu: false
control_net_cpu: false
vae_on_cpu: false
diffusion flash attention: false
diffusion Conv2d direct: false
vae_conv_direct: false
control_strength: 0.90
prompt: a comic hero holding a sign "Qwen-Image now also works on Rocm"
negative_prompt:
clip_skip: -1
width: 512
height: 512
sample_params: (txt_cfg: 2.50, img_cfg: 2.50, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: euler, sample_steps: 10, eta: 0.00, shifted_timestep: 0)
high_noise_sample_params: (txt_cfg: 7.00, img_cfg: 7.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: default, sample_method: default, sample_steps: -1, eta: 0.00, shifted_timestep: 0)
moe_boundary: 0.875
flow_shift: 3.00
strength(img2img): 0.75
rng: cuda
seed: 42
batch_count: 1
vae_tiling: false
upscale_repeats: 1
chroma_use_dit_mask: true
chroma_use_t5_mask: false
chroma_t5_mask_pad: 1
video_frames: 1
vace_strength: 1.00
fps: 16
System Info:
SSE3 = 1
AVX = 1
AVX2 = 1
AVX512 = 0
AVX512_VBMI = 0
AVX512_VNNI = 0
FMA = 1
NEON = 0
ARM_FMA = 0
F16C = 1
FP16_VA = 0
WASM_SIMD = 0
VSX = 0
[DEBUG] stable-diffusion.cpp:147 - Using CUDA backend
[INFO ] ggml_extend.hpp:69 - ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
[INFO ] ggml_extend.hpp:69 - ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
[INFO ] ggml_extend.hpp:69 - ggml_cuda_init: found 1 ROCm devices:
[INFO ] ggml_extend.hpp:69 - Device 0: AMD Radeon Pro WX 9100, gfx900:xnack- (0x900), VMM: no, Wave Size: 64
[INFO ] stable-diffusion.cpp:211 - loading diffusion model from '/daten/models/unet/Qwen_Image_Distill-Q4_0.gguf'
[INFO ] model.cpp:1072 - load /daten/models/unet/Qwen_Image_Distill-Q4_0.gguf using gguf format
[DEBUG] model.cpp:1089 - init from '/daten/models/unet/Qwen_Image_Distill-Q4_0.gguf'
[INFO ] stable-diffusion.cpp:258 - loading qwen2vl from '/daten/models/text_encoders/Qwen2.5-VL-7B-Instruct-Q8_0.gguf'
[INFO ] model.cpp:1072 - load /daten/models/text_encoders/Qwen2.5-VL-7B-Instruct-Q8_0.gguf using gguf format
[DEBUG] model.cpp:1089 - init from '/daten/models/text_encoders/Qwen2.5-VL-7B-Instruct-Q8_0.gguf'
[INFO ] stable-diffusion.cpp:265 - loading vae from '/daten/models/vae/qwen_image_vae.safetensors'
[INFO ] model.cpp:1075 - load /daten/models/vae/qwen_image_vae.safetensors using safetensors format
[DEBUG] model.cpp:1182 - init from '/daten/models/vae/qwen_image_vae.safetensors', prefix = 'vae.'
[INFO ] stable-diffusion.cpp:277 - Version: Qwen Image
[INFO ] stable-diffusion.cpp:308 - Weight type: bf16
[INFO ] stable-diffusion.cpp:309 - Conditioner weight type: f32
[INFO ] stable-diffusion.cpp:310 - Diffusion model weight type: bf16
[INFO ] stable-diffusion.cpp:311 - VAE weight type: NONE
[DEBUG] stable-diffusion.cpp:313 - ggml tensor size = 400 bytes
[DEBUG] qwenvl.hpp:137 - merges size 151387
[DEBUG] qwenvl.hpp:159 - vocab size: 151665
[DEBUG] ggml_extend.hpp:1769 - qwenvl2.5 params backend buffer size = 7165.44 MB(RAM) (338 tensors)
[DEBUG] ggml_extend.hpp:1769 - qwen_image params backend buffer size = 11303.54 MB(RAM) (1933 tensors)
[DEBUG] ggml_extend.hpp:1769 - wan_vae params backend buffer size = 139.84 MB(RAM) (108 tensors)
[DEBUG] stable-diffusion.cpp:580 - loading weights
[DEBUG] model.cpp:2005 - using 4 threads for model loading
[DEBUG] model.cpp:2088 - loading tensors from /daten/models/unet/Qwen_Image_Distill-Q4_0.gguf
|=======================================> | 1933/2465 - 18.14it/s
[DEBUG] model.cpp:2088 - loading tensors from /daten/models/text_encoders/Qwen2.5-VL-7B-Instruct-Q8_0.gguf
|==============================================> | 2271/2465 - 12.68it/s
[DEBUG] model.cpp:2088 - loading tensors from /daten/models/vae/qwen_image_vae.safetensors
|==============================================> | 2274/2465 - 12.70it/s[INFO ] model.cpp:2358 - unknown tensor 'first_stage_model.conv1.bias | bf16 | 1 [32, 1, 1, 1, 1]' in model file
[INFO ] model.cpp:2358 - unknown tensor 'first_stage_model.conv1.weight | bf16 | 4 [1, 1, 1, 1024, 1]' in model file
|==================================================| 2465/2465 - 13.62it/s
[INFO ] model.cpp:2332 - loading tensors completed, taking 181.03s (process: 0.01s, read: 180.06s, memcpy: 0.00s, convert: 0.03s, copy_to_backend: 0.00s)
[INFO ] stable-diffusion.cpp:676 - total params memory size = 18608.83MB (VRAM 18608.83MB, RAM 0.00MB): text_encoders 7165.44MB(VRAM), diffusion_model 11303.55MB(VRAM), vae 139.84MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)
[INFO ] stable-diffusion.cpp:723 - running in FLOW mode
[DEBUG] stable-diffusion.cpp:747 - finished loaded file
[DEBUG] stable-diffusion.cpp:2383 - generate_image 512x512
[INFO ] stable-diffusion.cpp:2510 - TXT2IMG
[INFO ] stable-diffusion.cpp:896 - attempting to apply 0 LoRAs
[INFO ] stable-diffusion.cpp:916 - apply_loras completed, taking 0.00s
[DEBUG] stable-diffusion.cpp:917 - prompt after extract and remove lora: "a comic hero holding a sign "Qwen-Image now also works on Rocm""
[DEBUG] conditioner.hpp:1416 - parse '<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
a comic hero holding a sign "Qwen-Image now also works on Rocm"<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
a comic hero holding a sign "Qwen-Image now also works on Rocm"<|im_end|>
<|im_start|>assistant
', 1], ]
[INFO ] ggml_extend.hpp:1693 - qwenvl2.5 offload params (7165.44 MB, 338 tensors) to runtime backend (ROCm0), taking 1.81s
[DEBUG] ggml_extend.hpp:1593 - qwenvl2.5 compute buffer size: 10.58 MB(VRAM)
[DEBUG] conditioner.hpp:1486 - computing condition graph completed, taking 2592 ms
[DEBUG] conditioner.hpp:1416 - parse '<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>system
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
<|im_end|>
<|im_start|>assistant
', 1], ]
[INFO ] ggml_extend.hpp:1693 - qwenvl2.5 offload params (7165.44 MB, 338 tensors) to runtime backend (ROCm0), taking 1.10s
[DEBUG] ggml_extend.hpp:1593 - qwenvl2.5 compute buffer size: 7.24 MB(VRAM)
[DEBUG] conditioner.hpp:1486 - computing condition graph completed, taking 1390 ms
[INFO ] stable-diffusion.cpp:2121 - get_learned_condition completed, taking 4024 ms
[INFO ] stable-diffusion.cpp:2146 - sampling using Euler method
[INFO ] stable-diffusion.cpp:2240 - generating image: 1/1 - seed 42
[INFO ] ggml_extend.hpp:1693 - qwen_image offload params (11303.54 MB, 1933 tensors) to runtime backend (ROCm0), taking 6.24s
[DEBUG] ggml_extend.hpp:1593 - qwen_image compute buffer size: 187.55 MB(VRAM)
|==================================================| 10/10 - 7.73s/it
[INFO ] stable-diffusion.cpp:2277 - sampling completed, taking 83.66s
[INFO ] stable-diffusion.cpp:2285 - generating 1 latent images completed, taking 84.08s
[INFO ] stable-diffusion.cpp:2288 - decoding 1 latents
[INFO ] ggml_extend.hpp:1693 - wan_vae offload params (139.84 MB, 108 tensors) to runtime backend (ROCm0), taking 0.06s
[DEBUG] ggml_extend.hpp:1593 - wan_vae compute buffer size: 1874.50 MB(VRAM)
[DEBUG] stable-diffusion.cpp:1611 - computing vae decode graph completed, taking 1.66s
[INFO ] stable-diffusion.cpp:2298 - latent 1 decoded, taking 1.66s
[INFO ] stable-diffusion.cpp:2302 - decode_first_stage completed, taking 1.66s
[INFO ] stable-diffusion.cpp:2586 - generate_image completed in 89.76s
save result PNG image to 'output.png'

But similar as before with Flux and T5xxl
using GGML_CUDA_FORCE_CUBLAS: yes, black image no matter what i try clip or vae on cpu doesn't change anything.

For now im testing with GGML_CUDA_FORCE_CUBLAS: no, but for all other models it makes generation 40% slower

Since Flux and t5xxl is now fine with GGML_CUDA_FORCE_CUBLAS: yes maybe something similar happens here, maybe the unet itself!?

@evcharger
Copy link

evcharger commented Oct 13, 2025

Just a shoutout to anyone wanting to try this out with low-powered hardware - I am successfully running this on a 2019 Ryzen 5 3400G with an iGPU (Vega RX 11) with 16 GB of Ram allocated for VRAM and 24 GB of GTT memory under Ubuntu (on a total of 64 GB system ram) - on the Q4_0 quant I am generating the same image of a cat as above at 512 px for 504s. (as opposed to the 16.85s. achieved with the 4090 above). Using the Q8_0 quant at 1024 px, things become very slow - All 16 GB Vram used + 8 GB GTT and another 21 GB offloaded to system ram, so a total of ~ 45 GB memory used and around 330s/it. Thank you @leejet ! edit - it's running on Vulkan

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants