Skip to content

Commit accf137

Browse files
committed
feat: add cli arg to gradio demo for nbit quantization
1 parent 99012bd commit accf137

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

app.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,8 @@
55
import random
66
import spaces
77

8-
98
from OmniGen import OmniGenPipeline
109

11-
pipe = OmniGenPipeline.from_pretrained(
12-
"Shitao/OmniGen-v1"
13-
)
1410

1511
@spaces.GPU(duration=180)
1612
def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
@@ -370,6 +366,8 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_
370366

371367
with gr.Column():
372368
with gr.Column():
369+
# quantization = gr.Radio(["4bit (NF4)", "8bit", "None (bf16)"], label="bitsandbytes quantization", value="4bit (NF4)")
370+
# quantization.input(change_quantization, inputs=quantization, trigger_mode="once", concurrency_limit=1)
373371
# output image
374372
output_image = gr.Image(label="Output Image")
375373
save_images = gr.Checkbox(label="Save generated images", value=False)
@@ -425,7 +423,21 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_
425423
if __name__ == "__main__":
426424
parser = argparse.ArgumentParser(description='Run the OmniGen')
427425
parser.add_argument('--share', action='store_true', help='Share the Gradio app')
426+
parser.add_argument('-b', '--nbits', choices=['4','8'], help='bitsandbytes quantization n-bits')
428427
args = parser.parse_args()
429428

429+
if args.nbits == '4':
430+
quantization_config = 'bnb_4bit'
431+
elif args.nbits == '8':
432+
quantization_config = 'bnb_8bit'
433+
else:
434+
quantization_config = None
435+
436+
pipe = OmniGenPipeline.from_pretrained(
437+
"Shitao/OmniGen-v1",
438+
quantization_config = quantization_config,
439+
low_cpu_mem_usage=True,
440+
)
441+
430442
# launch
431443
demo.launch(share=args.share)

0 commit comments

Comments
 (0)