
Commit a01e6eb

Add support for profiling via setup(enable_profiling=True).

This generates a chrome-trace.json file during prediction that you can load with Chrome DevTools (chrome://tracing) to inspect CPU and GPU usage.

1 parent e75132a
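
Since the trace is plain JSON in the Chrome trace event format, it can be sanity-checked from Python before loading it in the browser. A minimal sketch, assuming a chrome-trace.json sits in the working directory (this snippet is illustrative and not part of the commit):

```python
import json

# torch.profiler's export_chrome_trace writes Chrome-trace-format JSON:
# either a bare event list or an object with a "traceEvents" key.
with open("chrome-trace.json") as f:
    trace = json.load(f)

events = trace["traceEvents"] if isinstance(trace, dict) else trace
print(f"{len(events)} trace events recorded")
```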

File tree

1 file changed
predict.py

Lines changed: 57 additions & 30 deletions
@@ -1,3 +1,4 @@
+import contextlib
 import os
 import time
 from typing import Any, Tuple, Optional
@@ -136,8 +137,10 @@ def base_setup(
         compile_fp8: bool = False,
         compile_bf16: bool = False,
         disable_fp8: bool = False,
+        enable_profiling: bool = False,
     ) -> None:
         self.flow_model_name = flow_model_name
+        self.enable_profiling = enable_profiling
         print(f"Booting model {self.flow_model_name}")

         gpu_name = (
@@ -477,6 +480,7 @@ def postprocess(
         output_format: str,
         output_quality: int,
         np_images: Optional[List[Image]] = None,
+        profile: Optional[Path] = None,
     ) -> List[Path]:
         has_nsfw_content = [False] * len(images)

@@ -513,7 +517,10 @@ def postprocess(
         )

         print(f"Total safe images: {len(output_paths)} out of {len(images)}")
-        return output_paths
+        if profile:
+            return [profile, output_paths]
+        else:
+            return output_paths

     def run_safety_checker(self, images, np_images):
         safety_checker_input = self.feature_extractor(images, return_tensors="pt").to(
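
One subtlety in the hunk above: when a profile is present, postprocess returns [profile, output_paths], a two-element list whose second element is itself a list, which no longer matches the declared List[Path] annotation. A hypothetical caller-side unpacking (not in this commit) could look like:

```python
# Hypothetical helper, not part of the commit: recover the trace path and
# the image paths from either return shape of postprocess.
def split_result(result):
    # Profiling on: result == [trace_path, [img_path, ...]]
    if len(result) == 2 and isinstance(result[1], list):
        return result[0], result[1]
    # Profiling off: result is the flat [img_path, ...] list
    return None, result
```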
@@ -547,32 +554,48 @@ def shared_predict(
         seed: int = None,
         width: int = 1024,
         height: int = 1024,
-    ):
-        if go_fast and not self.disable_fp8:
-            return self.fp8_predict(
-                prompt=prompt,
-                num_outputs=num_outputs,
-                num_inference_steps=num_inference_steps,
-                guidance=guidance,
-                image=image,
-                prompt_strength=prompt_strength,
-                seed=seed,
-                width=width,
-                height=height,
+    ) -> Tuple[List[Image.Image], Optional[List[np.ndarray]], Optional[Path]]:
+        if self.enable_profiling:
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ]
             )
-        if self.disable_fp8:
-            print("running bf16 model, fp8 disabled")
-        return self.base_predict(
-            prompt=prompt,
-            num_outputs=num_outputs,
-            num_inference_steps=num_inference_steps,
-            guidance=guidance,
-            image=image,
-            prompt_strength=prompt_strength,
-            seed=seed,
-            width=width,
-            height=height,
-        )
+        else:
+            profiler = contextlib.nullcontext()
+
+        with profiler:
+            if go_fast and not self.disable_fp8:
+                imgs, np_imgs = self.fp8_predict(
+                    prompt=prompt,
+                    num_outputs=num_outputs,
+                    num_inference_steps=num_inference_steps,
+                    guidance=guidance,
+                    image=image,
+                    prompt_strength=prompt_strength,
+                    seed=seed,
+                    width=width,
+                    height=height,
+                )
+            else:
+                if self.disable_fp8:
+                    print("running bf16 model, fp8 disabled")
+                imgs, np_imgs = self.base_predict(
+                    prompt=prompt,
+                    num_outputs=num_outputs,
+                    num_inference_steps=num_inference_steps,
+                    guidance=guidance,
+                    image=image,
+                    prompt_strength=prompt_strength,
+                    seed=seed,
+                    width=width,
+                    height=height,
+                )
+        if isinstance(profiler, torch.profiler.profile):
+            profiler.export_chrome_trace("chrome-trace.json")
+            return imgs, np_imgs, Path("chrome-trace.json")
+        return imgs, np_imgs, None


 class SchnellPredictor(Predictor):
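
The shape of shared_predict above, a real torch.profiler.profile when the flag is set and contextlib.nullcontext() otherwise, keeps a single code path for both modes. A self-contained sketch of the same pattern; the helper name run_with_optional_profile and the fn callback are illustrative, not from this commit:

```python
import contextlib
from pathlib import Path
from typing import Callable, Optional, Tuple

import torch


def run_with_optional_profile(
    fn: Callable[[], object], enable_profiling: bool
) -> Tuple[object, Optional[Path]]:
    # Same idea as shared_predict: record CPU and CUDA activity when
    # profiling is on, or use a no-op context manager when it is off.
    if enable_profiling:
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        )
    else:
        profiler = contextlib.nullcontext()

    with profiler:
        result = fn()

    # Export only when we actually profiled; the trace loads in
    # Chrome DevTools or chrome://tracing.
    if isinstance(profiler, torch.profiler.profile):
        profiler.export_chrome_trace("chrome-trace.json")
        return result, Path("chrome-trace.json")
    return result, None
```

Exporting after the with block matters: the profiler finalizes its results when the context exits, so export_chrome_trace is only valid once profiling has stopped.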
@@ -598,7 +621,7 @@ def predict(
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -614,6 +637,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )

@@ -656,7 +680,7 @@ def predict(
             print("img2img not supported with fp8 quantization; running with bf16")
             go_fast = False
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -675,6 +699,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )

@@ -706,7 +731,7 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)

         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -722,6 +747,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )

@@ -770,7 +796,7 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)

         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        imgs, np_imgs, profile = self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
@@ -789,6 +815,7 @@ def predict(
             output_format,
             output_quality,
             np_images=np_imgs,
+            profile=profile,
         )

