1
+ import contextlib
1
2
import os
2
3
import time
3
4
from typing import Any , Tuple , Optional
@@ -136,8 +137,10 @@ def base_setup(
136
137
compile_fp8 : bool = False ,
137
138
compile_bf16 : bool = False ,
138
139
disable_fp8 : bool = False ,
140
+ enable_profiling : bool = False ,
139
141
) -> None :
140
142
self .flow_model_name = flow_model_name
143
+ self .enable_profiling = enable_profiling
141
144
print (f"Booting model { self .flow_model_name } " )
142
145
143
146
gpu_name = (
@@ -477,6 +480,7 @@ def postprocess(
477
480
output_format : str ,
478
481
output_quality : int ,
479
482
np_images : Optional [List [Image ]] = None ,
483
+ profile : Optional [Path ] = None ,
480
484
) -> List [Path ]:
481
485
has_nsfw_content = [False ] * len (images )
482
486
@@ -513,7 +517,10 @@ def postprocess(
513
517
)
514
518
515
519
print (f"Total safe images: { len (output_paths )} out of { len (images )} " )
516
- return output_paths
520
+ if profile :
521
+ return [profile , output_paths ]
522
+ else :
523
+ return output_paths
517
524
518
525
def run_safety_checker (self , images , np_images ):
519
526
safety_checker_input = self .feature_extractor (images , return_tensors = "pt" ).to (
@@ -547,32 +554,48 @@ def shared_predict(
547
554
seed : int = None ,
548
555
width : int = 1024 ,
549
556
height : int = 1024 ,
550
- ):
551
- if go_fast and not self .disable_fp8 :
552
- return self .fp8_predict (
553
- prompt = prompt ,
554
- num_outputs = num_outputs ,
555
- num_inference_steps = num_inference_steps ,
556
- guidance = guidance ,
557
- image = image ,
558
- prompt_strength = prompt_strength ,
559
- seed = seed ,
560
- width = width ,
561
- height = height ,
557
+ ) -> Tuple [List [Image .Image ], Optional [List [np .ndarray ]], Optional [Path ]]:
558
+ if self .enable_profiling :
559
+ profiler = torch .profiler .profile (
560
+ activities = [
561
+ torch .profiler .ProfilerActivity .CPU ,
562
+ torch .profiler .ProfilerActivity .CUDA ,
563
+ ]
562
564
)
563
- if self .disable_fp8 :
564
- print ("running bf16 model, fp8 disabled" )
565
- return self .base_predict (
566
- prompt = prompt ,
567
- num_outputs = num_outputs ,
568
- num_inference_steps = num_inference_steps ,
569
- guidance = guidance ,
570
- image = image ,
571
- prompt_strength = prompt_strength ,
572
- seed = seed ,
573
- width = width ,
574
- height = height ,
575
- )
565
+ else :
566
+ profiler = contextlib .nullcontext ()
567
+
568
+ with profiler :
569
+ if go_fast and not self .disable_fp8 :
570
+ imgs , np_imgs = self .fp8_predict (
571
+ prompt = prompt ,
572
+ num_outputs = num_outputs ,
573
+ num_inference_steps = num_inference_steps ,
574
+ guidance = guidance ,
575
+ image = image ,
576
+ prompt_strength = prompt_strength ,
577
+ seed = seed ,
578
+ width = width ,
579
+ height = height ,
580
+ )
581
+ else :
582
+ if self .disable_fp8 :
583
+ print ("running bf16 model, fp8 disabled" )
584
+ imgs , np_imgs = self .base_predict (
585
+ prompt = prompt ,
586
+ num_outputs = num_outputs ,
587
+ num_inference_steps = num_inference_steps ,
588
+ guidance = guidance ,
589
+ image = image ,
590
+ prompt_strength = prompt_strength ,
591
+ seed = seed ,
592
+ width = width ,
593
+ height = height ,
594
+ )
595
+ if isinstance (profiler , torch .profiler .profile ):
596
+ profiler .export_chrome_trace ("chrome-trace.json" )
597
+ return imgs , np_imgs , Path ("chrome-trace.json" )
598
+ return imgs , np_imgs , None
576
599
577
600
578
601
class SchnellPredictor (Predictor ):
@@ -598,7 +621,7 @@ def predict(
598
621
megapixels : str = SHARED_INPUTS .megapixels ,
599
622
) -> List [Path ]:
600
623
width , height = self .preprocess (aspect_ratio , megapixels )
601
- imgs , np_imgs = self .shared_predict (
624
+ imgs , np_imgs , profile = self .shared_predict (
602
625
go_fast ,
603
626
prompt ,
604
627
num_outputs ,
@@ -614,6 +637,7 @@ def predict(
614
637
output_format ,
615
638
output_quality ,
616
639
np_images = np_imgs ,
640
+ profile = profile ,
617
641
)
618
642
619
643
@@ -656,7 +680,7 @@ def predict(
656
680
print ("img2img not supported with fp8 quantization; running with bf16" )
657
681
go_fast = False
658
682
width , height = self .preprocess (aspect_ratio , megapixels )
659
- imgs , np_imgs = self .shared_predict (
683
+ imgs , np_imgs , profile = self .shared_predict (
660
684
go_fast ,
661
685
prompt ,
662
686
num_outputs ,
@@ -675,6 +699,7 @@ def predict(
675
699
output_format ,
676
700
output_quality ,
677
701
np_images = np_imgs ,
702
+ profile = profile ,
678
703
)
679
704
680
705
@@ -706,7 +731,7 @@ def predict(
706
731
self .handle_loras (go_fast , lora_weights , lora_scale )
707
732
708
733
width , height = self .preprocess (aspect_ratio , megapixels )
709
- imgs , np_imgs = self .shared_predict (
734
+ imgs , np_imgs , profile = self .shared_predict (
710
735
go_fast ,
711
736
prompt ,
712
737
num_outputs ,
@@ -722,6 +747,7 @@ def predict(
722
747
output_format ,
723
748
output_quality ,
724
749
np_images = np_imgs ,
750
+ profile = profile ,
725
751
)
726
752
727
753
@@ -770,7 +796,7 @@ def predict(
770
796
self .handle_loras (go_fast , lora_weights , lora_scale )
771
797
772
798
width , height = self .preprocess (aspect_ratio , megapixels )
773
- imgs , np_imgs = self .shared_predict (
799
+ imgs , np_imgs , profile = self .shared_predict (
774
800
go_fast ,
775
801
prompt ,
776
802
num_outputs ,
@@ -789,6 +815,7 @@ def predict(
789
815
output_format ,
790
816
output_quality ,
791
817
np_images = np_imgs ,
818
+ profile = profile ,
792
819
)
793
820
794
821
0 commit comments