|
5 | 5 | import random
|
6 | 6 | import spaces
|
7 | 7 |
|
8 |
| - |
9 | 8 | from OmniGen import OmniGenPipeline
|
10 | 9 |
|
11 |
| -pipe = OmniGenPipeline.from_pretrained( |
12 |
| - "Shitao/OmniGen-v1" |
13 |
| -) |
14 | 10 |
|
15 | 11 | @spaces.GPU(duration=180)
|
16 | 12 | def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
|
@@ -370,6 +366,8 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_
|
370 | 366 |
|
371 | 367 | with gr.Column():
|
372 | 368 | with gr.Column():
|
| 369 | + # quantization = gr.Radio(["4bit (NF4)", "8bit", "None (bf16)"], label="bitsandbytes quantization", value="4bit (NF4)") |
| 370 | + # quantization.input(change_quantization, inputs=quantization, trigger_mode="once", concurrency_limit=1) |
373 | 371 | # output image
|
374 | 372 | output_image = gr.Image(label="Output Image")
|
375 | 373 | save_images = gr.Checkbox(label="Save generated images", value=False)
|
@@ -425,7 +423,21 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_
|
425 | 423 | if __name__ == "__main__":
|
426 | 424 | parser = argparse.ArgumentParser(description='Run the OmniGen')
|
427 | 425 | parser.add_argument('--share', action='store_true', help='Share the Gradio app')
|
| 426 | + parser.add_argument('-b', '--nbits', choices=['4','8'], help='bitsandbytes quantization n-bits') |
428 | 427 | args = parser.parse_args()
|
429 | 428 |
|
| 429 | + if args.nbits == '4': |
| 430 | + quantization_config = 'bnb_4bit' |
| 431 | + elif args.nbits == '8': |
| 432 | + quantization_config = 'bnb_8bit' |
| 433 | + else: |
| 434 | + quantization_config = None |
| 435 | + |
| 436 | + pipe = OmniGenPipeline.from_pretrained( |
| 437 | + "Shitao/OmniGen-v1", |
| 438 | + quantization_config = quantization_config, |
| 439 | + low_cpu_mem_usage=True, |
| 440 | + ) |
| 441 | + |
430 | 442 | # launch
|
431 | 443 | demo.launch(share=args.share)
|
0 commit comments