Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable NPU for paint your dreams #147

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions demos/paint_your_dreams_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@

def get_available_devices() -> list[str]:
core = ov.Core()
# NPU is not supported with this application
return list({device.split(".")[0] for device in core.available_devices if device != "NPU"})
return list({device.split(".")[0] for device in core.available_devices})


def download_models(model_name: str, safety_checker_model: str) -> None:
Expand All @@ -62,15 +61,43 @@ def download_models(model_name: str, safety_checker_model: str) -> None:
image_processor=AutoProcessor.from_pretrained(safety_checker_dir))


async def load_pipeline(model_name: str, device: str):
async def load_npu_pipeline(model_dir: Path, size: int) -> genai.Text2ImagePipeline:
# NPU requires model static input shape for now
ov_config = {"CACHE_DIR": "cache"}

scheduler = genai.Scheduler.from_config(model_dir / "scheduler" / "scheduler_config.json")

text_encoder = genai.CLIPTextModel(model_dir / "text_encoder")
text_encoder.reshape(1)
text_encoder.compile("NPU", **ov_config)

unet = genai.UNet2DConditionModel(model_dir / "unet")
max_position_embeddings = text_encoder.get_config().max_position_embeddings
unet.reshape(1, size, size, max_position_embeddings)
unet.compile("NPU", **ov_config)

vae = genai.AutoencoderKL(model_dir / "vae_decoder")
vae.reshape(1, size, size)
vae.compile("NPU", **ov_config)

ov_pipeline = genai.Text2ImagePipeline.latent_consistency_model(scheduler, text_encoder, unet, vae)

return ov_pipeline


async def load_pipeline(model_name: str, device: str, size: int) -> genai.Text2ImagePipeline:
if device not in ov_pipelines:
model_dir = MODEL_DIR / model_name
ov_config = {"CACHE_DIR": "cache"}

ov_pipeline = genai.Text2ImagePipeline(model_dir, device, **ov_config)
ov_pipelines[device] = ov_pipeline
if device == "NPU":
ov_pipeline = await load_npu_pipeline(model_dir, size)
else:
ov_pipeline = genai.Text2ImagePipeline(model_dir, device, **ov_config)

ov_pipelines[(device, size)] = ov_pipeline

return ov_pipelines[device]
return ov_pipelines[(device, size)]


async def stop():
Expand All @@ -82,7 +109,7 @@ async def generate_images(prompt: str, seed: int, size: int, guidance_scale: flo
global stop_generating
stop_generating = not endless_generation

ov_pipeline = await load_pipeline(hf_model_name, device)
ov_pipeline = await load_pipeline(hf_model_name, device, size)

while True:
if randomize_seed:
Expand Down Expand Up @@ -112,7 +139,7 @@ async def generate_images(prompt: str, seed: int, size: int, guidance_scale: flo
await asyncio.sleep(0.1)


def build_ui():
def build_ui() -> gr.Interface:
examples = [
"A sail boat on a grass field with mountains in the morning and sunny day",
"Portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour,"
Expand Down Expand Up @@ -195,7 +222,7 @@ def build_ui():
return demo


def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: bool = False, public_interface: bool = False):
def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: bool = False, public_interface: bool = False) -> None:
global hf_model_name
hf_model_name = model_name
server_name = "0.0.0.0" if local_network else None
Expand Down
Loading