Add Stable Diffusion demo #100

Draft: wants to merge 16 commits into main
27 changes: 27 additions & 0 deletions tripy/examples/diffusion/README.md
@@ -0,0 +1,27 @@
# Implementing Stable Diffusion With Tripy

## Introduction

This example demonstrates how to implement a Stable Diffusion model using Tripy APIs.

It's broken up into three components:

1. `model.py` defines the model using `tripy.Module` and associated APIs, while `clip_model.py`, `unet_model.py`, and `vae_model.py` implement the individual components of the diffusion model (the `tp.Module` pattern is sketched after this list).
2. `weight_loader.py` loads weights from a HuggingFace checkpoint.
3. `example.py` runs the end-to-end example, taking input text as a command-line argument, running inference, and then displaying the generated output.
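
Each of these networks is built from `tp.Module` subclasses. As a rough sketch of the pattern (modeled on the `CLIPMLP` block in `clip_model.py`; the layer sizes here are illustrative only):

```python
import tripy as tp

class MLP(tp.Module):
    def __init__(self, embedding_size: int = 768, dtype: tp.dtype = tp.float32):
        self.fc1 = tp.Linear(embedding_size, embedding_size * 4, dtype=dtype)
        self.fc2 = tp.Linear(embedding_size * 4, embedding_size, dtype=dtype)

    def __call__(self, x):
        h = self.fc1(x)
        h = tp.sigmoid(1.702 * h) * h  # "quick GELU" activation
        return self.fc2(h)
```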

The model defaults to running in `float32`, but we recommend running it in `float16` by passing the `--fp16` flag if you have less than 20-24 GB of GPU memory (note that normalization layers still run in `float32` to preserve accuracy; see the sketch below).
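
The cast-around-the-norm pattern referred to above can be seen in the model code (for example in `clip_model.py`); a rough, self-contained sketch of the idea:

```python
import tripy as tp

# Illustrative shapes and values only: a float16 activation passed through a float32 LayerNorm.
hidden_states = tp.cast(tp.full((1, 77, 768), 1.0), tp.float16)
norm = tp.LayerNorm(768, dtype=tp.float32)
out = tp.cast(norm(tp.cast(hidden_states, norm.dtype)), tp.float16)
```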

## Running The Example

1. Install prerequisites:

```bash
python3 -m pip install -r requirements.txt
```

2. Run the example:

```bash
python3 example.py --seed 0 --steps 50 --prompt "a beautiful photograph of Mt. Fuji during cherry blossom" --fp16 --engine-dir fp16_engines
```
125 changes: 125 additions & 0 deletions tripy/examples/diffusion/clip_model.py
@@ -0,0 +1,125 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass

import tripy as tp

from examples.diffusion.helper import scaled_dot_product_attention

@dataclass
class CLIPConfig:
    vocab_size: int = 49408
    embedding_size: int = 768
    num_heads: int = 12
    max_seq_len: int = 77
    num_hidden_layers: int = 12
    dtype: tp.dtype = tp.float32

class CLIPMLP(tp.Module):
    def __init__(self, config: CLIPConfig):
        self.fc1 = tp.Linear(config.embedding_size, config.embedding_size * 4, dtype=config.dtype)
        self.fc2 = tp.Linear(config.embedding_size * 4, config.embedding_size, dtype=config.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = tp.sigmoid(1.702 * hidden_states) * hidden_states  # quick GELU
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPAttention(tp.Module):
    def __init__(self, config: CLIPConfig):
        self.embed_dim = config.embedding_size
        self.num_heads = config.num_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.k_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
        self.v_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
        self.q_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
        self.out_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype)
        self.dtype = config.dtype

    def __call__(self, hidden_states, causal_attention_mask):
        bsz, tgt_len, embed_dim = hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2]
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        # Split each projection into heads: (bsz, seq, embed) -> (bsz, heads, seq, head_dim).
        q, k, v = [
            tp.transpose(
                tp.reshape(x, (bsz, tgt_len, self.num_heads, self.head_dim)),
                1,
                2,
            )
            for x in (q, k, v)
        ]
        attn_output = scaled_dot_product_attention(
            q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask,
        )
        # Merge the heads back together and apply the output projection.
        out = self.out_proj(tp.reshape(tp.transpose(attn_output, 1, 2), (bsz, tgt_len, embed_dim)))
        return out


class CLIPEncoderLayer(tp.Module):
    def __init__(self, config: CLIPConfig):
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=tp.float32)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=tp.float32)

    def __call__(self, hidden_states, causal_attention_mask):
        # Pre-norm attention block; the layer norm runs in float32, so cast in and back out.
        residual = hidden_states
        hidden_states = tp.cast(self.layer_norm1(tp.cast(hidden_states, self.layer_norm1.dtype)), hidden_states.dtype)
        hidden_states = self.self_attn(hidden_states, causal_attention_mask)
        hidden_states = residual + hidden_states

        # Pre-norm MLP block.
        residual = hidden_states
        hidden_states = tp.cast(self.layer_norm2(tp.cast(hidden_states, self.layer_norm2.dtype)), hidden_states.dtype)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class CLIPEncoder(tp.Module):
    def __init__(self, config: CLIPConfig):
        self.layers = [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]

    def __call__(self, hidden_states, causal_attention_mask):
        for layer in self.layers:
            hidden_states = layer(hidden_states, causal_attention_mask)
        return hidden_states


class CLIPTextEmbeddings(tp.Module):
    def __init__(self, config: CLIPConfig):
        self.token_embedding = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
        self.position_embedding = tp.Embedding(config.max_seq_len, config.embedding_size, dtype=config.dtype)

    def __call__(self, input_ids, position_ids):
        return self.token_embedding(input_ids) + self.position_embedding(position_ids)


class CLIPTextTransformer(tp.Module):
    def __init__(self, config: CLIPConfig):
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=tp.float32)
        self.max_seq_len = config.max_seq_len

    def __call__(self, input_ids):
        # Position ids are simply 0..seq_len-1, broadcast over the batch dimension.
        x = self.embeddings(input_ids, tp.reshape(tp.iota((input_ids.shape[1],), dtype=tp.int32), (1, -1)))
        # Causal mask: upper-triangular -inf so each token attends only to itself and earlier tokens.
        x = self.encoder(x, tp.triu(tp.full((1, 1, self.max_seq_len, self.max_seq_len), float("-inf")), 1))
        return tp.cast(self.final_layer_norm(tp.cast(x, self.final_layer_norm.dtype)), x.dtype)
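
For context, a minimal sketch of how the text encoder defined above could be exercised in isolation. The token ids below are placeholder values produced with `tp.iota`, not real tokenizer output, and in the full example `weight_loader.py` would first populate the parameters from the HuggingFace checkpoint.

```python
# Hypothetical usage sketch (not part of clip_model.py).
config = CLIPConfig()
text_encoder = CLIPTextTransformer(config)
dummy_ids = tp.reshape(tp.iota((config.max_seq_len,), dtype=tp.int32), (1, -1))
prompt_embeddings = text_encoder(dummy_ids)  # expected shape: (1, 77, 768)
```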