Update image saving and test
akhilg-nv committed Oct 15, 2024
1 parent 25ecf87 commit d9dd478
Showing 4 changed files with 143 additions and 74 deletions.
9 changes: 4 additions & 5 deletions tripy/examples/diffusion/README.md
@@ -6,12 +6,11 @@ This example demonstrates how to implement a Stable Diffusion model using Tripy

It's broken up into three components:

1. `model.py` defines the model using `tripy.Module` and associated APIs.
1. `model.py` defines the model using `tripy.Module` and associated APIs; `clip_model.py`, `unet_model.py`, and `vae_model.py` implement the individual components of the diffusion model.
2. `weight_loader.py` loads weights from a HuggingFace checkpoint.
3. `example.py` runs the end-to-end example, taking input text as a command-line argument,
running inference, and then displaying the generated output.
3. `example.py` runs the end-to-end example, taking input text as a command-line argument, running inference, and then displaying the generated output.

The model is currently implemented in `float32`.
The model defaults to running in `float32`, but running it in `float16` with the `--fp16` flag is recommended if you have less than 20-24 GB of GPU memory (note that normalization layers will still run in `float32` to preserve accuracy).
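
As a rough sketch of what the flag does inside `example.py` (the `pick_dtype` helper is illustrative, not part of the example):

```python
import torch

def pick_dtype(fp16: bool) -> torch.dtype:
    # --fp16 selects half precision end to end; normalization layers are
    # handled in float32 separately inside the model code to preserve accuracy.
    return torch.float16 if fp16 else torch.float32
```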

## Running The Example

@@ -24,5 +23,5 @@ The model is currently implemented in `float32`.
2. Run the example:

```bash
python3 example.py --seed 0 --steps 50 --prompt "a beautiful photograph of Mt. Fuji during cherry blossom"
python3 example.py --seed 0 --steps 50 --prompt "a beautiful photograph of Mt. Fuji during cherry blossom" --fp16 --engine-dir fp16_engines
```
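
Unless `--out` is given, the generated image is written to an `output/` directory with a filename that encodes the backend (`tp` or `torch`), the precision, a prefix of the prompt, the step count, the seed, and a timestamp.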
142 changes: 87 additions & 55 deletions tripy/examples/diffusion/example.py
@@ -1,5 +1,5 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,16 +34,15 @@
def compile_model(model, inputs, verbose=False):
if verbose:
name = model.__class__.__name__ if isinstance(model, tp.Module) else model.__name__
print(f"[I] Compiling {name}...", end=' ', flush=True)
print(f"[I] Compiling {name}...", end=" ", flush=True)
compile_start_time = time.perf_counter()

compiler = tp.Compiler(model)
compiled_model = compiler.compile(*inputs)
compiled_model = tp.compile(model, args=inputs)

if verbose:
compile_end_time = time.perf_counter()
print(f"took {compile_end_time - compile_start_time} seconds.")

return compiled_model
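
For context, this commit replaces the `tp.Compiler(model)` / `compiler.compile(*inputs)` pattern with a single module-level call. A minimal sketch of the new pattern, using only the two calls visible in this diff (`module`, `sample_inputs`, and the file name are placeholders):

```python
import tripy as tp

# `module` is a tp.Module (or plain function) and `sample_inputs` its example
# inputs, as forwarded by compile_model() above.
compiled = tp.compile(module, args=sample_inputs)

# Serialize the executable so later runs can load it instead of recompiling.
compiled.save("model_executable.json")
```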


@@ -94,6 +93,28 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
return latent


def save_image(image, args):
if args.out:
filename = args.out
else:
filename = (
f"{'torch' if args.torch_inference else 'tp'}-"
f"{'fp16' if args.fp16 else 'fp32'}-"
f"{args.prompt[:10].replace(' ', '_')}-"
f"steps{args.steps}-"
f"seed{args.seed if args.seed else 'rand'}-"
f"{int(time.time())}.png"
)

target = os.path.join("output", filename)
# Save image
print(f"[I] Saving image to {target}")
if not os.path.isdir("output"):
print("[I] Creating 'output' directory.")
os.mkdir("output")
image.save(target)
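
For example, a Tripy run with `--fp16 --steps 50 --seed 42` and the README's prompt would be saved as something like `output/tp-fp16-a_beautifu-steps50-seed42-1728950400.png`, where the trailing number is the Unix timestamp. Note that `--seed 0` produces `seedrand`, since `0` is falsy in the `args.seed if args.seed else 'rand'` expression.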


def tripy_diffusion(args):
run_start_time = time.perf_counter()

@@ -111,19 +132,23 @@ def tripy_diffusion(args):
clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
unet_compiled = compile_unet(model, dtype, verbose=True)
vae_compiled = compile_vae(model.decode, dtype, verbose=True)

os.makedirs(args.engine_dir, exist_ok=True)
print(f"[I] Saving engines to {args.engine_dir}...")
print(f"[I] Saving engines to ./{args.engine_dir}...")
clip_compiled.save(os.path.join(args.engine_dir, "clip_executable.json"))
unet_compiled.save(os.path.join(args.engine_dir, "unet_executable.json"))
vae_compiled.save(os.path.join(args.engine_dir, "vae_executable.json"))

# Run through CLIP to get context from prompt
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
torch_prompt = tokenizer(args.prompt, padding="max_length", max_length=CLIPConfig.max_seq_len, truncation=True, return_tensors="pt")
torch_prompt = tokenizer(
args.prompt, padding="max_length", max_length=CLIPConfig.max_seq_len, truncation=True, return_tensors="pt"
)
prompt = tp.Tensor(torch_prompt.input_ids.to(torch.int32).to("cuda"))
print(f"[I] Got tokenized prompt.")
torch_unconditional_prompt = tokenizer([""], padding="max_length", max_length=CLIPConfig.max_seq_len, return_tensors="pt")
torch_unconditional_prompt = tokenizer(
[""], padding="max_length", max_length=CLIPConfig.max_seq_len, return_tensors="pt"
)
unconditional_prompt = tp.Tensor(torch_unconditional_prompt.input_ids.to(torch.int32).to("cuda"))
print(f"[I] Got unconditional tokenized prompt.")

@@ -157,16 +182,11 @@ def tripy_diffusion(args):
run_end_time = time.perf_counter()
print(f"[I] Full script took {run_end_time - run_start_time} seconds.")

# Save image
image = Image.fromarray(cp.from_dlpack(x).get().astype(np.uint8, copy=False))
print(f"[I] Saving {args.out}")
if not os.path.isdir("output"):
print("[I] Creating 'output' directory.")
os.mkdir("output")
image.save(args.out)

return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]


# referenced from https://huggingface.co/blog/stable_diffusion
def hf_diffusion(args):
from transformers import CLIPTextModel, CLIPTokenizer
@@ -176,23 +196,35 @@ def hf_diffusion(args):
run_start_time = time.perf_counter()

dtype = torch.float16 if args.fp16 else torch.float32
model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if args.fp16 else {}
model_opts = {"variant": "fp16", "torch_dtype": torch.float16} if args.fp16 else {}

# Initialize models
model_id = "KiwiXR/stable-diffusion-v1-5"

print("[I] Loading models...")
hf_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
hf_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", use_auth_token=args.hf_token, **model_opts).to("cuda")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_auth_token=args.hf_token, **model_opts).to("cuda")
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
unet = UNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", use_auth_token=args.hf_token, **model_opts
).to("cuda")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_auth_token=args.hf_token, **model_opts).to(
"cuda"
)
scheduler = LMSDiscreteScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)

# Run through CLIP to get context from prompt
print("[I] Starting tokenization and running clip...", end=" ")
clip_run_start = time.perf_counter()
text_input = hf_tokenizer(args.prompt, padding="max_length", max_length=hf_tokenizer.model_max_length, truncation=True, return_tensors="pt").to("cuda")
max_length = text_input.input_ids.shape[-1] # 77
text_input = hf_tokenizer(
args.prompt,
padding="max_length",
max_length=hf_tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).to("cuda")
max_length = text_input.input_ids.shape[-1] # 77
uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda")
text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0]
uncond_embeddings = hf_encoder(uncond_input.input_ids)[0]
@@ -205,7 +237,7 @@ def hf_diffusion(args):
torch.manual_seed(args.seed)
torch_latent = torch.randn((1, 4, 64, 64), dtype=dtype).to("cuda")
torch_latent *= scheduler.init_noise_sigma

scheduler.set_timesteps(args.steps)

diffusion_run_start = time.perf_counter()
@@ -248,54 +280,54 @@
run_end_time = time.perf_counter()
print(f"[I] Full script took {run_end_time - run_start_time} seconds.")

# Save image
print(f"[I] Saving {args.out}")
if not os.path.isdir("output"):
print("[I] Creating 'output' directory.")
os.mkdir("output")
image.save(args.out)
return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]


def print_summary(denoising_steps, times):
stages_ms = [1000 * (times[i+1] - times[i]) for i in range(0, 6, 2)]
stages_ms = [1000 * (times[i + 1] - times[i]) for i in range(0, 6, 2)]
total_ms = sum(stages_ms)
print('|-----------------|--------------|')
print('| {:^15} | {:^12} |'.format('Module', 'Latency'))
print('|-----------------|--------------|')
print('| {:^15} | {:>9.2f} ms |'.format('CLIP', stages_ms[0]))
print('| {:^15} | {:>9.2f} ms |'.format('UNet'+' x '+str(denoising_steps), stages_ms[1]))
print('| {:^15} | {:>9.2f} ms |'.format('VAE-Dec', stages_ms[2]))
print('|-----------------|--------------|')
print('| {:^15} | {:>9.2f} ms |'.format('Pipeline', total_ms))
print('|-----------------|--------------|')
print('Throughput: {:.2f} image/s'.format(1000. / total_ms))


# TODO: Add torch compilation modes
# TODO: Add Timing context
print("|-----------------|--------------|")
print("| {:^15} | {:^12} |".format("Module", "Latency"))
print("|-----------------|--------------|")
print("| {:^15} | {:>9.2f} ms |".format("CLIP", stages_ms[0]))
print("| {:^15} | {:>9.2f} ms |".format("UNet" + " x " + str(denoising_steps), stages_ms[1]))
print("| {:^15} | {:>9.2f} ms |".format("VAE-Dec", stages_ms[2]))
print("|-----------------|--------------|")
print("| {:^15} | {:>9.2f} ms |".format("Pipeline", total_ms))
print("|-----------------|--------------|")
print("Throughput: {:.2f} image/s".format(1000.0 / total_ms))


# TODO: Add torch compilation
# TODO: Add Timing context (depends on how we measure perf)
def main():
default_prompt = "a horse sized cat eating a bagel"
parser = argparse.ArgumentParser(
description="Run Stable Diffusion", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--steps", type=int, default=10, help="Number of denoising steps in diffusion")
parser.add_argument("--prompt", type=str, default=default_prompt, help="Phrase to render")
parser.add_argument("--out", type=str, default=os.path.join("output", "rendered.png"), help="Output filename")
parser.add_argument("--out", type=str, default=None, help="Output filepath")
parser.add_argument("--fp16", action="store_true", help="Cast the weights to float16")
parser.add_argument("--timing", action="store_true", help="Print timing per step")
parser.add_argument("--seed", type=int, help="Set the random latent seed")
parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
parser.add_argument('--torch-inference', action='store_true', help="Run inference with PyTorch (eager mode) instead of TensorRT.")
parser.add_argument('--hf-token', type=str, default='', help="HuggingFace API access token for downloading model checkpoints")
parser.add_argument('--engine-dir', type=str, default='engines', help="Output directory for TensorRT engines")
parser.add_argument(
"--torch-inference", action="store_true", help="Run inference with PyTorch (eager mode) instead of TensorRT."
)
parser.add_argument(
"--hf-token", type=str, default="", help="HuggingFace API access token for downloading model checkpoints"
)
parser.add_argument("--engine-dir", type=str, default="engines", help="Output directory for TensorRT engines")
args = parser.parse_args()

if args.torch_inference:
_, times = hf_diffusion(args)
print_summary(args.steps, times)
image, times = hf_diffusion(args)
else:
_, times = tripy_diffusion(args)
print_summary(args.steps, times)
image, times = tripy_diffusion(args)

save_image(image, args)
print_summary(args.steps, times)


if __name__ == "__main__":
main()
21 changes: 18 additions & 3 deletions tripy/examples/diffusion/weight_loader.py
@@ -1,15 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import tripy as tp

from diffusers import StableDiffusionPipeline


def load_weights_from_hf(model, hf_model, dtype, debug=False):
tripy_state_dict = model.state_dict()
tripy_keys = tripy_state_dict.keys()

hf_state_dict = hf_model.state_dict()
hf_keys = hf_state_dict.keys()

assert_msg = f"Mismatched keys: {hf_keys} != {tripy_keys}"
if debug and len(hf_keys) != len(tripy_keys):
print("\nERROR: unused weights in HF_state_dict:\n", sorted(list(hf_keys - tripy_keys)))
@@ -25,7 +40,8 @@ def load_weights_from_hf(model, hf_model, dtype, debug=False):
param = tp.Parameter(weight)
tripy_state_dict[key.removeprefix("text_model.")] = param

model.load_from_state_dict(tripy_state_dict)
model.load_state_dict(tripy_state_dict)


def load_from_diffusers(model, dtype, hf_token, debug=False):
model_id = "KiwiXR/stable-diffusion-v1-5"
@@ -34,4 +50,3 @@ def load_from_diffusers(model, dtype, hf_token, debug=False):
load_weights_from_hf(model.cond_stage_model.transformer.text_model, pipe.text_encoder, dtype, debug=debug)
load_weights_from_hf(model.model.diffusion_model, pipe.unet, dtype, debug=debug)
load_weights_from_hf(model.first_stage_model, pipe.vae, dtype, debug=debug)
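
In short, `load_from_diffusers` instantiates the HuggingFace `StableDiffusionPipeline` once and then copies each submodule's weights into the matching Tripy module: the CLIP text encoder, the UNet diffusion model, and the VAE.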

45 changes: 34 additions & 11 deletions tripy/tests/test_diffusion.py
@@ -1,3 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import numpy as np
@@ -23,22 +37,31 @@ def check_equal(tp_array, torch_tensor, dtype=torch.float32, debug=False):
rel_diff = torch.abs(diff) / (torch.abs(b) + eps)
max_rel_diff = torch.max(rel_diff)
print(f"Maximum relative difference: {max_rel_diff}\n")

assert torch.allclose(torch.from_dlpack(tp_array).to(dtype), torch_tensor.to(dtype)), f"\nTP Array:\n {tp_array} \n!= Torch Tensor:\n {torch_tensor}"

assert torch.allclose(
torch.from_dlpack(tp_array).to(dtype), torch_tensor.to(dtype)
), f"\nTP Array:\n {tp_array} \n!= Torch Tensor:\n {torch_tensor}"


@pytest.mark.l1
class TestConvolution:
def test_ssim(self):
args = Namespace(steps=50, prompt='a beautiful photograph of Mt. Fuji during cherry blossom', out='output/rendered.png', fp16=False, seed=100, guidance=7.5, torch_inference=False)
args = Namespace(
steps=50,
prompt="a beautiful photograph of Mt. Fuji during cherry blossom",
out="temp.png",
fp16=True,
seed=420,
guidance=7.5,
torch_inference=False,
hf_token="",
engine_dir="diffusion_engines",
)
tp_img, _ = tripy_diffusion(args)
print(f"first: {tp_img}")
tp_img = np.array(tp_img.convert('L'))
print(f"second: {tp_img}")
tp_img = np.array(tp_img.convert("L"))

torch_img, _ = hf_diffusion(args)
print(f"third: {torch_img}")
torch_img = np.array(torch_img.convert('L'))
print(f"fourth: {torch_img}")
torch_img = np.array(torch_img.convert("L"))

ssim = structural_similarity(tp_img, torch_img)
print(f"SSIM IS: {ssim}")
assert ssim >= 0.85, "Structural Similarity score expected >= 0.85 but got {ssim}"
assert ssim >= 0.80, f"Structural Similarity score expected >= 0.80 but got {ssim}"
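
Both images are converted to grayscale (`convert("L")`) before scoring because `structural_similarity` from scikit-image operates on single-channel arrays by default; color input would require the `channel_axis` argument.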
