From 36379e929e9567ea93f321267eb6a4602a1af7a1 Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Fri, 20 Dec 2024 14:59:22 +0100
Subject: [PATCH 1/9] Add HPU utils

---
 src/flux/hpu_utils.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)
 create mode 100644 src/flux/hpu_utils.py

diff --git a/src/flux/hpu_utils.py b/src/flux/hpu_utils.py
new file mode 100644
index 00000000..a036d918
--- /dev/null
+++ b/src/flux/hpu_utils.py
@@ -0,0 +1,13 @@
+import torch
+
+
+def load_model_to_hpu(model):
+    from habana_frameworks.torch.utils.library_loader import load_habana_module
+    load_habana_module()
+
+    device = "hpu"
+    if torch.hpu.is_available():
+        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+        model = wrap_in_hpu_graph(model)
+    model = model.eval().to(torch.device(device))
+    return model

From a1e9027878ae1d649410b794191fa58ed715b746 Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Fri, 20 Dec 2024 14:59:44 +0100
Subject: [PATCH 2/9] Add Intel Gaudi HPU usage in CLIs

---
 demo_gr.py              |  1 +
 src/flux/cli.py         | 18 ++++++----
 src/flux/cli_control.py | 17 ++++++----
 src/flux/cli_fill.py    | 22 ++++++++----
 src/flux/cli_redux.py   | 11 ++++--
 src/flux/util.py        | 74 ++++++++++++++++++++++++++++++++++++-----
 6 files changed, 114 insertions(+), 29 deletions(-)

diff --git a/demo_gr.py b/demo_gr.py
index 3b4d022b..ad84e833 100644
--- a/demo_gr.py
+++ b/demo_gr.py
@@ -91,6 +91,7 @@ def generate_image(
             dtype=torch.bfloat16,
             seed=opts.seed,
         )
+
         timesteps = get_schedule(
             opts.num_steps,
             x.shape[-1] * x.shape[-2] // 4,

diff --git a/src/flux/cli.py b/src/flux/cli.py
index e844c765..73debfa7 100644
--- a/src/flux/cli.py
+++ b/src/flux/cli.py
@@ -9,7 +9,8 @@ from transformers import pipeline
 
 from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial, get_dtype
+from hpu_utils import load_model_to_hpu
 
 NSFW_THRESHOLD = 0.85
 
@@ -101,7 +102,7 @@ def main(
         "a photo of a forest with mist swirling around the tree trunks. The word "
         '"FLUX" is painted over it in big, red brush strokes with visible texture'
     ),
-    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    device: str | None = None,
     num_steps: int | None = None,
     loop: bool = False,
     guidance: float = 3.5,
@@ -127,6 +128,7 @@ def main(
         guidance: guidance value used for guidance distillation
         add_sampling_metadata: Add the prompt to the image Exif metadata
     """
+    device = get_device_initial(device)
     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
 
     if name not in configs:
@@ -171,6 +173,7 @@ def main(
     if loop:
         opts = parse_prompt(opts)
 
+    dtype = get_dtype(str(device))
    while opts is not None:
         if opts.seed is None:
             opts.seed = rng.seed()
@@ -183,11 +186,14 @@ def main(
             opts.height,
             opts.width,
             device=torch_device,
-            dtype=torch.bfloat16,
+            dtype=dtype,
             seed=opts.seed,
         )
+        if str(device) == "hpu":
+            x = x.to(torch.device("hpu"))
+
         opts.seed = None
-        if offload:
+        if offload and str(device) != "hpu":
             ae = ae.cpu()
             torch.cuda.empty_cache()
             t5, clip = t5.to(torch_device), clip.to(torch_device)
@@ -204,14 +210,14 @@ def main(
         x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
 
         # offload model, load autoencoder to gpu
-        if offload:
+        if offload and str(device) != "hpu":
             model.cpu()
             torch.cuda.empty_cache()
             ae.decoder.to(x.device)
 
         # decode latents to pixel space
         x = unpack(x.float(), opts.height, opts.width)
-        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+        with torch.autocast(device_type=torch_device.type, dtype=dtype):
             x = ae.decode(x)
 
         if torch.cuda.is_available():

diff --git a/src/flux/cli_control.py b/src/flux/cli_control.py
index cd83c89e..c300047a 100644
--- a/src/flux/cli_control.py
+++ b/src/flux/cli_control.py
@@ -10,7 +10,8 @@
 
 from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
 from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial, get_dtype
+from hpu_utils import load_model_to_hpu
 
 
 @dataclass
@@ -165,7 +166,7 @@ def main(
     height: int = 1024,
     seed: int | None = None,
     prompt: str = "a robot made out of gold",
-    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    device: str | None = None,
     num_steps: int = 50,
     loop: bool = False,
     guidance: float | None = None,
@@ -193,6 +194,7 @@ def main(
         add_sampling_metadata: Add the prompt to the image Exif metadata
         img_cond_path: path to conditioning image (jpeg/png/webp)
     """
+    device = get_device_initial(device)
     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
 
     assert name in [
@@ -268,6 +270,7 @@ def main(
             if hasattr(module, "set_scale"):
                 module.set_scale(opts.lora_scale)
 
+    dtype = get_dtype(device)
     while opts is not None:
         if opts.seed is None:
             opts.seed = rng.seed()
@@ -280,9 +283,11 @@ def main(
             opts.height,
             opts.width,
             device=torch_device,
-            dtype=torch.bfloat16,
+            dtype=dtype,
             seed=opts.seed,
         )
+        if str(device) == "hpu":
+            x = x.to(torch.device("hpu"))
         opts.seed = None
         if offload:
             t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
@@ -298,7 +303,7 @@ def main(
         timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
 
         # offload TEs and AE to CPU, load model to gpu
-        if offload:
+        if offload and str(device) != "hpu":
             t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
             torch.cuda.empty_cache()
             model = model.to(torch_device)
@@ -307,14 +312,14 @@ def main(
         x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
 
         # offload model, load autoencoder to gpu
-        if offload:
+        if offload and str(device) != "hpu":
             model.cpu()
             torch.cuda.empty_cache()
             ae.decoder.to(x.device)
 
         # decode latents to pixel space
         x = unpack(x.float(), opts.height, opts.width)
-        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+        with torch.autocast(device_type=torch_device.type, dtype=dtype):
             x = ae.decode(x)
 
         if torch.cuda.is_available():

diff --git a/src/flux/cli_fill.py b/src/flux/cli_fill.py
index 415c0420..20160c7e 100644
--- a/src/flux/cli_fill.py
+++ b/src/flux/cli_fill.py
@@ -5,12 +5,13 @@
 from glob import iglob
 
 import torch
-from fire import Fire
 from PIL import Image
+from fire import Fire
 from transformers import pipeline
 
 from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial, get_dtype
+from hpu_utils import load_model_to_hpu
 
 
 @dataclass
@@ -175,7 +176,7 @@ def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
 def main(
     seed: int | None = None,
     prompt: str = "a white paper cup",
-    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    device: str | None = None,
     num_steps: int = 50,
     loop: bool = False,
     guidance: float = 30.0,
@@ -203,6 +204,7 @@ def main(
         img_cond_path: path to conditioning image (jpeg/png/webp)
         img_mask_path: path to conditioning mask (jpeg/png/webp
     """
+    device = get_device_initial(device)
     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
 
     name = "flux-dev-fill"
@@ -254,6 +256,7 @@ def main(
 
         opts = parse_img_mask_path(opts)
 
+    dtype = get_dtype(str(device))
     while opts is not None:
         if opts.seed is None:
             opts.seed = rng.seed()
@@ -266,12 +269,19 @@ def main(
             opts.height,
             opts.width,
             device=torch_device,
-            dtype=torch.bfloat16,
+            dtype=dtype,
             seed=opts.seed,
         )
+        if str(device) == "hpu":
+            x = x.to(torch.device("hpu"))
+
         opts.seed = None
-        if offload:
+        if offload and str(device) != "hpu":
             t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
+        if str(device) == "hpu":
+            ae = load_model_to_hpu(ae)
+            clip = load_model_to_hpu(clip)
+
         inp = prepare_fill(
             t5,
             clip,
@@ -301,7 +311,7 @@ def main(
 
         # decode latents to pixel space
         x = unpack(x.float(), opts.height, opts.width)
-        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+        with torch.autocast(device_type=torch_device.type, dtype=dtype):
             x = ae.decode(x)
 
         if torch.cuda.is_available():

diff --git a/src/flux/cli_redux.py b/src/flux/cli_redux.py
index 6c03435a..29c3b478 100644
--- a/src/flux/cli_redux.py
+++ b/src/flux/cli_redux.py
@@ -10,7 +10,8 @@
 
 from flux.modules.image_embedders import ReduxImageEncoder
 from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_dtype
+from hpu_utils import load_model_to_hpu
 
 
 @dataclass
@@ -206,6 +207,7 @@ def main(
         opts = parse_prompt(opts)
         opts = parse_img_cond_path(opts)
 
+    dtype = get_dtype(str(device))
     while opts is not None:
         if opts.seed is None:
             opts.seed = rng.seed()
@@ -218,9 +220,12 @@ def main(
             opts.height,
             opts.width,
             device=torch_device,
-            dtype=torch.bfloat16,
+            dtype=dtype,
             seed=opts.seed,
         )
+        if str(device) == "hpu":
+            x = x.to(torch.device("hpu"))
+
         opts.seed = None
         if offload:
             ae = ae.cpu()
@@ -253,7 +258,7 @@ def main(
         # decode latents to pixel space
         x = unpack(x.float(), opts.height, opts.width)
-        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+        with torch.autocast(device_type=torch_device.type, dtype=dtype):
             x = ae.decode(x)
 
         if torch.cuda.is_available():

diff --git a/src/flux/util.py b/src/flux/util.py
index 26b9cb26..d6a8b5dd 100644
--- a/src/flux/util.py
+++ b/src/flux/util.py
@@ -1,3 +1,4 @@
+import importlib.util
 import os
 from dataclasses import dataclass
 
@@ -11,6 +12,50 @@
 from flux.model import Flux, FluxLoraWrapper, FluxParams
 from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
 from flux.modules.conditioner import HFEmbedder
+from hpu_utils import load_model_to_hpu
+
+
+def get_device_initial(preferred_device=None):
+    """
+    Determine the appropriate device to use (cuda, hpu, or cpu).
+
+    Args:
+        preferred_device (str): User-preferred device ('cuda', 'hpu', or 'cpu').
+
+    Returns:
+        str: Device string ('cuda', 'hpu', or 'cpu').
+    """
+    # Check for HPU support
+    if importlib.util.find_spec("habana_frameworks") is not None:
+        from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+        load_habana_module()
+        if torch.hpu.is_available():
+            if preferred_device == "hpu" or preferred_device is None:
+                return "hpu"
+
+    # Check for CUDA (GPU support)
+    if torch.cuda.is_available():
+        if preferred_device == "cuda" or preferred_device is None:
+            return "cuda"
+
+    # Default to CPU
+    return "cpu"
+
+
+def get_dtype(device: str) -> torch.dtype:
+    """
+    Determine the appropriate dtype to use based on the device.
+
+    Args:
+        device (str): Device string ('cuda', 'hpu', or 'cpu').
+
+    Returns:
+        torch.dtype: Data type (torch.float32 or torch.bfloat16).
+    """
+    if "hpu" in device:
+        return torch.float32
+    return torch.bfloat16
 
 
 def save_image(
     nsfw_classifier,
     name: str,
@@ -314,7 +359,7 @@ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
 
 def load_flow_model(
-    name: str, device: str | torch.device = "cuda", hf_download: bool = True, verbose: bool = False
+    name: str, device: str | torch.device = get_device_initial(), hf_download: bool = True, verbose: bool = False
 ) -> Flux:
     # Loading Flux
     print("Init model")
@@ -328,11 +373,12 @@ def load_flow_model(
     ):
         ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
 
+    dtype = get_dtype(str(device))
     with torch.device("meta" if ckpt_path is not None else device):
         if lora_path is not None:
-            model = FluxLoraWrapper(params=configs[name].params).to(torch.bfloat16)
+            model = FluxLoraWrapper(params=configs[name].params).to(dtype)
         else:
-            model = Flux(configs[name].params).to(torch.bfloat16)
+            model = Flux(configs[name].params).to(dtype)
 
     if ckpt_path is not None:
         print("Loading checkpoint")
@@ -353,16 +399,28 @@ def load_flow_model(
     return model
 
 
-def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+def load_t5(device: str | torch.device = get_device_initial(), max_length: int = 512) -> HFEmbedder:
+    dtype = get_dtype(str(device))
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
-    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+    model_init = HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=dtype)
+    if str(device) == "hpu":
+        # load the model to HPU
+        model = load_model_to_hpu(model_init)
+        return model
+    return model_init.to(device)
 
 
-def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
-    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+def load_clip(device: str | torch.device = get_device_initial()) -> HFEmbedder:
+    dtype = get_dtype(str(device))
+    model_init = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=dtype)
+    if str(device) == "hpu":
+        # load the model to HPU
+        model = load_model_to_hpu(model_init)
+        return model
+    return model_init.to(device)
 
 
-def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
+def load_ae(name: str, device: str | torch.device = get_device_initial(), hf_download: bool = True) -> AutoEncoder:
     ckpt_path = configs[name].ae_path
     if (
         ckpt_path is None

From cc6f1e97800f247fd84d367029fffb7e81c9f017 Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Fri, 20 Dec 2024 14:59:57 +0100
Subject: [PATCH 3/9] Add Dockerfile and requirements

---
 Dockerfile.hpu       | 26 ++++++++++++++++++++++++++
 requirements_hpu.txt |  6 ++++++
 2 files changed, 32 insertions(+)
 create mode 100644 Dockerfile.hpu
 create mode 100644 requirements_hpu.txt

diff --git a/Dockerfile.hpu b/Dockerfile.hpu
new file mode 100644
index 00000000..8ef2dd46
--- /dev/null
+++ b/Dockerfile.hpu
@@ -0,0 +1,26 @@
+# Use the official Gaudi Docker image with PyTorch
+FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+
+# Set environment variables for Habana
+ENV HABANA_VISIBLE_DEVICES=all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=none
+ENV PT_HPU_LAZY_ACC_PAR_MODE=0
+ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=1
+
+# Set timezone to UTC and install essential packages
+ENV DEBIAN_FRONTEND="noninteractive" TZ=Etc/UTC
+RUN apt-get update && apt-get install -y \
+    tzdata \
+    python3-pip \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY . /workspace/flux
+WORKDIR /workspace/flux
+
+# Copy HPU requirements
+COPY requirements_hpu.txt /workspace/requirements_hpu.txt
+
+# Install Python packages
+RUN pip install --upgrade pip \
+    && pip install -e ".[all]" \
+    && pip install -r requirements_hpu.txt

diff --git a/requirements_hpu.txt b/requirements_hpu.txt
new file mode 100644
index 00000000..8f6e6dba
--- /dev/null
+++ b/requirements_hpu.txt
@@ -0,0 +1,6 @@
+optimum-habana==1.14.1
+transformers==4.45.2
+huggingface-hub==0.26.2
+tiktoken==0.8.0
+torch-geometric==2.6.1
+numba==0.60.0

From c45eb3f0687987f7edef096151e8d64f24299d5d Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Fri, 20 Dec 2024 17:54:53 +0100
Subject: [PATCH 4/9] Update the imports and loading HPU model

---
 src/flux/cli.py                     |  9 +++++---
 src/flux/cli_control.py             |  4 ++--
 src/flux/cli_fill.py                | 11 ++++++----
 src/flux/cli_redux.py               |  4 ++--
 src/flux/hpu_utils.py               | 15 ++++++++++++++
 src/flux/modules/image_embedders.py | 20 ++++++++++++++++--
 src/flux/modules/layers.py          | 12 +++++++----
 src/flux/sampling.py                | 32 +++++++++++++++++++++--------
 src/flux/util.py                    | 17 +--------------
 9 files changed, 82 insertions(+), 42 deletions(-)

diff --git a/src/flux/cli.py b/src/flux/cli.py
index 73debfa7..6ee1c3d5 100644
--- a/src/flux/cli.py
+++ b/src/flux/cli.py
@@ -9,8 +9,8 @@ from transformers import pipeline
 
 from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial, get_dtype
-from hpu_utils import load_model_to_hpu
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial
+from flux.hpu_utils import load_model_to_hpu, get_dtype
 
 NSFW_THRESHOLD = 0.85
 
@@ -201,11 +201,14 @@ def main(
         timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
 
         # offload TEs to CPU, load model to gpu
-        if offload:
+        if offload and str(device) != "hpu":
             t5, clip = t5.cpu(), clip.cpu()
             torch.cuda.empty_cache()
             model = model.to(torch_device)
 
+        if str(device) == "hpu":
+            model = load_model_to_hpu(model)
+
         # denoise initial noise
         x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
 
diff --git a/src/flux/cli_control.py b/src/flux/cli_control.py
index c300047a..18a9366e 100644
--- a/src/flux/cli_control.py
+++ b/src/flux/cli_control.py
@@ -10,8 +10,8 @@
 
 from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
 from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial, get_dtype
-from hpu_utils import load_model_to_hpu
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial
+from flux.hpu_utils import load_model_to_hpu, get_dtype
 
 
 @dataclass
diff --git a/src/flux/cli_fill.py b/src/flux/cli_fill.py
index 20160c7e..3b2e8100 100644
--- a/src/flux/cli_fill.py
+++ b/src/flux/cli_fill.py
@@ -10,8 +10,8 @@
 from transformers import pipeline
 
 from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial, get_dtype
-from hpu_utils import load_model_to_hpu
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_device_initial
+from flux.hpu_utils import load_model_to_hpu, get_dtype
 
 
 @dataclass
@@ -295,16 +295,19 @@ def main(
         timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
 
         # offload TEs and AE to CPU, load model to gpu
-        if offload:
+        if offload and str(device) != "hpu":
             t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
             torch.cuda.empty_cache()
             model = model.to(torch_device)
 
+        if str(device) == "hpu":
+            model = load_model_to_hpu(model)
+
         # denoise initial noise
         x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
 
         # offload model, load autoencoder to gpu
-        if offload:
+        if offload and str(device) != "hpu":
             model.cpu()
             torch.cuda.empty_cache()
             ae.decoder.to(x.device)
diff --git a/src/flux/cli_redux.py b/src/flux/cli_redux.py
index 29c3b478..314c7014 100644
--- a/src/flux/cli_redux.py
+++ b/src/flux/cli_redux.py
@@ -10,8 +10,8 @@
 
 from flux.modules.image_embedders import ReduxImageEncoder
 from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image, get_dtype
-from hpu_utils import load_model_to_hpu
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
+from flux.hpu_utils import load_model_to_hpu, get_dtype
 
 
 @dataclass
diff --git a/src/flux/hpu_utils.py b/src/flux/hpu_utils.py
index a036d918..9d99005d 100644
--- a/src/flux/hpu_utils.py
+++ b/src/flux/hpu_utils.py
@@ -11,3 +11,18 @@ def load_model_to_hpu(model):
         model = wrap_in_hpu_graph(model)
     model = model.eval().to(torch.device(device))
     return model
+
+
+def get_dtype(device: str) -> torch.dtype:
+    """
+    Determine the appropriate dtype to use based on the device.
+
+    Args:
+        device (str): Device string ('cuda', 'hpu', or 'cpu').
+
+    Returns:
+        torch.dtype: Data type (torch.float32 or torch.bfloat16).
+    """
+    if "hpu" in device:
+        return torch.float32
+    return torch.bfloat16
diff --git a/src/flux/modules/image_embedders.py b/src/flux/modules/image_embedders.py
index e7177d2f..bd84ac1c 100644
--- a/src/flux/modules/image_embedders.py
+++ b/src/flux/modules/image_embedders.py
@@ -10,6 +10,7 @@
 from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
 
 from flux.util import print_load_warning
+from flux.hpu_utils import load_model_to_hpu
 
 
 class DepthImageEncoder:
@@ -17,8 +18,20 @@ class DepthImageEncoder:
 
     def __init__(self, device):
         self.device = device
-        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
-        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
+        self.depth_model = self._get_depth_model()
+        self.processor = self._get_processor()
+
+    def _get_depth_model(self):
+        _model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name)
+        if str(self.device) == "hpu":
+            return load_model_to_hpu(_model)
+        return _model.to(self.device)
+
+    def _get_processor(self):
+        _processor = AutoProcessor.from_pretrained(self.depth_model_name)
+        # processors hold no weights and no device state,
+        # so they are returned as-is for every device
+        return _processor
 
     def __call__(self, img: torch.Tensor) -> torch.Tensor:
         hw = img.shape[-2:]
@@ -60,6 +73,9 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
         canny = torch.from_numpy(canny).float() / 127.5 - 1.0
         canny = rearrange(canny, "h w -> 1 1 h w")
         canny = repeat(canny, "b 1 ... -> b 3 ...")
+
+        # a plain tensor move also covers the "hpu" device
+
         return canny.to(self.device)
 
diff --git a/src/flux/modules/layers.py b/src/flux/modules/layers.py
index 091ddf62..42fe250c 100644
--- a/src/flux/modules/layers.py
+++ b/src/flux/modules/layers.py
@@ -1,11 +1,12 @@
-import math
 from dataclasses import dataclass
 
 import torch
 from einops import rearrange
 from torch import Tensor, nn
+import math
 
 from flux.math import attention, rope
 
 
 class EmbedND(nn.Module):
@@ -36,9 +37,12 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        t.device
-    )
+
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
+    # .to() moves tensors to any device, including "hpu"
+    freqs = freqs.to(t.device)
 
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
diff --git a/src/flux/sampling.py b/src/flux/sampling.py
index 048b76cf..ed0ee856 100644
--- a/src/flux/sampling.py
+++ b/src/flux/sampling.py
@@ -1,12 +1,13 @@
-import math
 from typing import Callable
 
 import numpy as np
 import torch
-from einops import rearrange, repeat
 from PIL import Image
+from einops import rearrange, repeat
 from torch import Tensor
+import math
 
+from flux.hpu_utils import get_dtype
 from .model import Flux
 from .modules.autoencoder import AutoEncoder
 from .modules.conditioner import HFEmbedder
@@ -94,7 +95,10 @@ def prepare_control(
         img_cond = encoder(img_cond)
         img_cond = ae.encode(img_cond)
 
-    img_cond = img_cond.to(torch.bfloat16)
+    # get dtype
+    dtype = get_dtype(str(img.device))
+
+    img_cond = img_cond.to(dtype)
     img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     if img_cond.shape[0] == 1 and bs > 1:
         img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
@@ -128,13 +132,18 @@ def prepare_fill(
     mask = torch.from_numpy(mask).float() / 255.0
     mask = rearrange(mask, "h w -> 1 1 h w")
 
+    dtype = get_dtype(str(img.device))
     with torch.no_grad():
+        # .to() also moves tensors to "hpu"; no special handling is needed
         img_cond = img_cond.to(img.device)
         mask = mask.to(img.device)
         img_cond = img_cond * (1 - mask)
         img_cond = ae.encode(img_cond)
     mask = mask[:, 0, :, :]
-    mask = mask.to(torch.bfloat16)
+    mask = mask.to(dtype)
     mask = rearrange(
         mask,
         "b (h ph) (w pw) -> b (ph pw) h w",
@@ -145,7 +154,7 @@ def prepare_fill(
     if mask.shape[0] == 1 and bs > 1:
         mask = repeat(mask, "1 ... -> bs ...", bs=bs)
 
-    img_cond = img_cond.to(torch.bfloat16)
+    img_cond = img_cond.to(dtype)
     img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     if img_cond.shape[0] == 1 and bs > 1:
         img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
@@ -153,7 +162,11 @@ def prepare_fill(
     img_cond = torch.cat((img_cond, mask), dim=-1)
 
     return_dict = prepare(t5, clip, img, prompt)
+
+    # a plain tensor move also covers the "hpu" device
     return_dict["img_cond"] = img_cond.to(img.device)
     return return_dict
 
 
@@ -166,6 +179,7 @@ def prepare_redux(
     img_cond_path: str,
 ) -> dict[str, Tensor]:
     bs, _, h, w = img.shape
+    dtype = get_dtype(str(img.device))
     if bs == 1 and not isinstance(prompt, str):
         bs = len(prompt)
 
@@ -173,7 +187,7 @@ def prepare_redux(
     with torch.no_grad():
         img_cond = encoder(img_cond)
 
-    img_cond = img_cond.to(torch.bfloat16)
+    img_cond = img_cond.to(dtype)
     if img_cond.shape[0] == 1 and bs > 1:
         img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
diff --git a/src/flux/util.py b/src/flux/util.py
index d6a8b5dd..606e3b12 100644
--- a/src/flux/util.py
+++ b/src/flux/util.py
@@ -12,7 +12,7 @@
 from flux.model import Flux, FluxLoraWrapper, FluxParams
 from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
 from flux.modules.conditioner import HFEmbedder
-from hpu_utils import load_model_to_hpu
+from flux.hpu_utils import load_model_to_hpu
 
 
 def get_device_initial(preferred_device=None):
@@ -43,21 +43,6 @@ def get_device_initial(preferred_device=None):
     return "cpu"
 
 
-def get_dtype(device: str) -> torch.dtype:
-    """
-    Determine the appropriate dtype to use based on the device.
-
-    Args:
-        device (str): Device string ('cuda', 'hpu', or 'cpu').
-
-    Returns:
-        torch.dtype: Data type (torch.float32 or torch.bfloat16).
-    """
-    if "hpu" in device:
-        return torch.float32
-    return torch.bfloat16
-
-
 def save_image(
     nsfw_classifier,
     name: str,

From 8e5669cb4e1b5af22b5806c420b778e7ef5032b7 Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Fri, 20 Dec 2024 18:51:55 +0100
Subject: [PATCH 5/9] Add Demo for HPU and update README.md

---
 README.md   |  22 +++++
 demo_hpu.py | 260 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 282 insertions(+)
 create mode 100644 demo_hpu.py

diff --git a/README.md b/README.md
index e73c1f13..a32d24b9 100644
--- a/README.md
+++ b/README.md
@@ -85,3 +85,25 @@ $ python -m flux.api --prompt="A beautiful beach" save outputs/api
 # open the image directly
 $ python -m flux.api --prompt="A beautiful beach" image show
 ```
+
+
+## Intel® Gaudi® HPU Usage
+
+### Build the Docker Image
+To run this project on an Intel® Gaudi® HPU, start by building a Docker image with the appropriate environment setup.
+
+```bash
+docker build -t flux_hpu:latest -f Dockerfile.hpu .
+```
+
+In the `Dockerfile.hpu`, we use the `vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest` base image. Ensure that the version matches your setup.
+See the [PyTorch Docker Images for the Intel® Gaudi® Accelerator](https://developer.habana.ai/catalog/pytorch-container/) for more information.
+
+### Run the Container
+
+```bash
+docker run -it --runtime=habana flux_hpu:latest
+```
+
+Optionally, you can add a volume mapping (`-v`) to access your project directory inside the container, as shown below.
+Replace `/path/to/your/project` with the path to your project directory on your local machine.
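+
+For example, with a local checkout at `/path/to/your/project` (an illustrative path), the combined command is:
+
+```bash
+docker run -it --runtime=habana -v /path/to/your/project:/workspace/project flux_hpu:latest
+```
+
+Inside the container, the HPU demo added in this patch set can then be launched with `streamlit run demo_hpu.py` (assuming Streamlit is available from the `[all]` extras).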
diff --git a/demo_hpu.py b/demo_hpu.py
new file mode 100644
index 00000000..7b2b4961
--- /dev/null
+++ b/demo_hpu.py
@@ -0,0 +1,260 @@
+import os
+import re
+import time
+from glob import iglob
+from io import BytesIO
+
+import streamlit as st
+import torch
+from PIL import ExifTags, Image
+from einops import rearrange
+from fire import Fire
+from st_keyup import st_keyup
+from torchvision import transforms
+from transformers import pipeline
+
+from flux.cli import SamplingOptions
+from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
+from flux.util import (
+    configs,
+    embed_watermark,
+    load_ae,
+    load_clip,
+    load_flow_model,
+    load_t5,
+)
+
+# Ensure that Habana SynapseAI is available
+try:
+    import habana_frameworks.torch.core as htcore
+    from habana_frameworks.torch.hpu import lazy_mode
+except ImportError:
+    st.error("Habana SynapseAI library is not installed. Please install it to use HPU.")
+
+NSFW_THRESHOLD = 0.85
+
+
+@st.cache_resource()
+def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
+    t5 = load_t5(device, max_length=256 if is_schnell else 512)
+    clip = load_clip(device)
+    model = load_flow_model(name, device="cpu" if offload else device)
+    ae = load_ae(name, device="cpu" if offload else device)
+    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
+    return model, ae, t5, clip, nsfw_classifier
+
+
+def get_image() -> torch.Tensor | None:
+    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
+    if image is None:
+        return None
+    image = Image.open(image).convert("RGB")
+
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: 2.0 * x - 1.0),
+        ]
+    )
+    img: torch.Tensor = transform(image)
+    return img[None, ...]
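+
+
+def hpu_mark_step(device: torch.device) -> None:
+    # Illustrative sketch, not part of the original demo: in Habana lazy
+    # mode, htcore.mark_step() marks a graph boundary and flushes the ops
+    # queued so far to the HPU. Calling it between pipeline stages (e.g.
+    # after denoise() and after ae.decode()) makes wall-clock timings on
+    # HPU meaningful, since execution is otherwise deferred.
+    if device.type == "hpu":
+        htcore.mark_step()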
+
+
+@torch.inference_mode()
+def main(
+    device: str = "hpu",
+    offload: bool = False,
+    output_dir: str = "output",
+):
+    if device == "hpu":
+        # Ensure lazy mode is enabled for Habana devices
+        lazy_mode(True)
+
+    torch_device = torch.device(device)
+    names = list(configs.keys())
+    name = st.selectbox("Which model to load?", names)
+    if name is None or not st.checkbox("Load model", False):
+        return
+
+    is_schnell = name == "flux-schnell"
+    model, ae, t5, clip, nsfw_classifier = get_models(
+        name,
+        device=torch_device,
+        offload=offload,
+        is_schnell=is_schnell,
+    )
+
+    do_img2img = (
+        st.checkbox(
+            "Image to Image",
+            False,
+            disabled=is_schnell,
+            help="Partially noise an image and denoise again to get variations.\n\nOnly works for flux-dev",
+        )
+        and not is_schnell
+    )
+    if do_img2img:
+        init_image = get_image()
+        if init_image is None:
+            st.warning("Please add an image to do image to image")
+        image2image_strength = st.number_input("Noising strength", min_value=0.0, max_value=1.0, value=0.8)
+        if init_image is not None:
+            h, w = init_image.shape[-2:]
+            st.write(f"Got image of size {w}x{h} ({h * w / 1e6:.2f}MP)")
+        resize_img = st.checkbox("Resize image", False) or init_image is None
+    else:
+        init_image = None
+        resize_img = True
+        image2image_strength = 0.0
+
+    width = int(
+        16 * (st.number_input("Width", min_value=128, value=1360, step=16, disabled=not resize_img) // 16)
+    )
+    height = int(
+        16 * (st.number_input("Height", min_value=128, value=768, step=16, disabled=not resize_img) // 16)
+    )
+    num_steps = int(st.number_input("Number of steps", min_value=1, value=(4 if is_schnell else 50)))
+    guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell))
+    seed_str = st.text_input("Seed", disabled=is_schnell)
+    if seed_str.isdecimal():
+        seed = int(seed_str)
+    else:
+        st.info("No seed set, set to positive integer to enable")
+        seed = None
+    save_samples = st.checkbox("Save samples?", not is_schnell)
+    add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)
+
+    default_prompt = (
+        "a photo of a forest with mist swirling around the tree trunks. The word "
+        '"FLUX" is painted over it in big, red brush strokes with visible texture'
+    )
+    prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text")
+
+    output_name = os.path.join(output_dir, "img_{idx}.jpg")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+        idx = 0
+    else:
+        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
+        if len(fns) > 0:
+            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
+        else:
+            idx = 0
+
+    rng = torch.Generator(device="cpu")
+
+    if "seed" not in st.session_state:
+        st.session_state.seed = rng.seed()
+
+    opts = SamplingOptions(
+        prompt=prompt,
+        width=width,
+        height=height,
+        num_steps=num_steps,
+        guidance=guidance,
+        seed=seed,
+    )
+
+    if st.button("Sample"):
+        if opts.seed is None:
+            opts.seed = rng.seed()
+        print(f"Generating '{opts.prompt}' with seed {opts.seed}")
+        t0 = time.perf_counter()
+
+        if init_image is not None:
+            if resize_img:
+                init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
+            else:
+                h, w = init_image.shape[-2:]
+                init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
+                opts.height = init_image.shape[-2]
+                opts.width = init_image.shape[-1]
+            if offload:
+                ae.encoder.to(torch_device)
+            init_image = ae.encode(init_image.to(torch_device))
+            if offload:
+                ae = ae.cpu()
+                torch.cuda.empty_cache()
+
+        x = get_noise(
+            1,
+            opts.height,
+            opts.width,
+            device=torch_device,
+            dtype=torch.bfloat16,
+            seed=opts.seed,
+        )
+        timesteps = get_schedule(
+            opts.num_steps,
+            (x.shape[-1] * x.shape[-2]) // 4,
+        )
+        if init_image is not None:
+            t_idx = int((1 - image2image_strength) * num_steps)
+            t = timesteps[t_idx]
+            timesteps = timesteps[t_idx:]
+            x = t * x + (1.0 - t) * init_image.to(x.dtype)
+
+        inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt)
+
+        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
+        x = unpack(x.float(), opts.height, opts.width)
+        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+            x = ae.decode(x)
+
+        t1 = time.perf_counter()
+
+        fn = output_name.format(idx=idx)
+        print(f"Done in {t1 - t0:.1f}s.")
+        x = x.clamp(-1, 1)
+        x = embed_watermark(x.float())
+        x = rearrange(x[0], "c h w -> h w c")
+
+        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+        nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
+
+        if nsfw_score < NSFW_THRESHOLD:
+            buffer = BytesIO()
+            exif_data = Image.Exif()
+            exif_data[ExifTags.Base.Software] = "AI generated"
+            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+            exif_data[ExifTags.Base.Model] = name
+            if add_sampling_metadata:
+                exif_data[ExifTags.Base.ImageDescription] = prompt
+            img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0)
+
+            img_bytes = buffer.getvalue()
+            if save_samples:
+                print(f"Saving {fn}")
+                with open(fn, "wb") as file:
+                    file.write(img_bytes)
+                idx += 1
+
+            st.session_state["samples"] = {
+                "prompt": opts.prompt,
+                "img": img,
+                "seed": opts.seed,
+                "bytes": img_bytes,
+            }
+            opts.seed = None
+        else:
+            st.warning("Your generated image may contain NSFW content.")
+            st.session_state["samples"] = None
+
+    samples = st.session_state.get("samples", None)
+    if samples is not None:
+        st.image(samples["img"], caption=samples["prompt"])
+        st.download_button(
+            "Download full-resolution",
+            samples["bytes"],
+            file_name="generated.jpg",
+            mime="image/jpg",
+        )
+        st.write(f"Seed: {samples['seed']}")
+
+
+def app():
+    Fire(main)
+
+
+if __name__ == "__main__":
+    app()

From bb5f3f1f3444b29df93a939a352097b5958cf935 Mon Sep 17 00:00:00 2001
From: Piotr Sobieszczyk <51410953+Sobiechh@users.noreply.github.com>
Date: Fri, 31 Jan 2025 15:50:09 +0100
Subject: [PATCH 6/9] Use clip and t5 models from optimum-habana package

---
 src/flux/hpu_utils.py | 42 +++++++++++++++++++++++++++++++++++++-----
 src/flux/util.py      |  6 ++++--
 2 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/src/flux/hpu_utils.py b/src/flux/hpu_utils.py
index 9d99005d..3c444a8d 100644
--- a/src/flux/hpu_utils.py
+++ b/src/flux/hpu_utils.py
@@ -1,16 +1,48 @@
+from typing import Optional
+
 import torch
 
 
-def load_model_to_hpu(model):
+def load_model_to_hpu(model, model_name=Optional[str]):
     from habana_frameworks.torch.utils.library_loader import load_habana_module
     load_habana_module()
 
-    device = "hpu"
-    if torch.hpu.is_available():
+    if not torch.hpu.is_available():
+        return model
+
+    # Check if model is HFEmbedder (which wraps CLIP or T5)
+    if model.__class__.__name__ == "HFEmbedder" and model_name:
+        if "t5" in model_name.lower():
+            from transformers.models.t5.modeling_t5 import (
+                T5ForConditionalGeneration
+            )
+            from optimum.habana.transformers.models.t5.modeling_t5 import (
+                gaudi_T5ForConditionalGeneration_forward,
+                gaudi_T5ForConditionalGeneration_prepare_inputs_for_generation
+            )
+
+            # Apply HPU-specific optimizations
+            T5ForConditionalGeneration.forward = gaudi_T5ForConditionalGeneration_forward
+            T5ForConditionalGeneration.prepare_inputs_for_generation = gaudi_T5ForConditionalGeneration_prepare_inputs_for_generation
+
+            # Initialize optimized model
+            model = T5ForConditionalGeneration.from_pretrained(model_name)
+
+        elif "clip" in model_name.lower():
+            from optimum.habana.transformers.models.clip import GaudiCLIPVisionModel
+            model = GaudiCLIPVisionModel.from_pretrained(
+                model_name,
+                use_flash_attention=True,
+                flash_attention_recompute=False
+            )
+
+    # Fallback to regular HPU loading
+    else:
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
         model = wrap_in_hpu_graph(model)
-    model = model.eval().to(torch.device(device))
-    return model
+        model = model.to(torch.device("hpu"))
+
+    return model.eval()
 
 
 def get_dtype(device: str) -> torch.dtype:
diff --git a/src/flux/util.py b/src/flux/util.py
index 606e3b12..7cf7df3e 100644
--- a/src/flux/util.py
+++ b/src/flux/util.py
@@ -387,7 +387,8 @@ def load_flow_model(
 def load_t5(device: str | torch.device = get_device_initial(), max_length: int = 512) -> HFEmbedder:
     dtype = get_dtype(str(device))
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
-    model_init = HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=dtype)
+    model_name = "google/t5-v1_1-xxl"
+    model_init = HFEmbedder(model_name, max_length=max_length, torch_dtype=dtype)
     if str(device) == "hpu":
         # load the model to HPU
         model = load_model_to_hpu(model_init)
@@ -397,7 +398,8 @@ def load_t5(device: str | torch.device = get_device_initial(), max_length: int = 512) -> HFEmbedder:
 
 def load_clip(device: str | torch.device = get_device_initial()) -> HFEmbedder:
     dtype = get_dtype(str(device))
-    model_init = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=dtype)
+    model_name = "openai/clip-vit-large-patch14"
+    model_init = HFEmbedder(model_name, max_length=77, torch_dtype=dtype)
     if str(device) == "hpu":
         # load the model to HPU
         model = load_model_to_hpu(model_init)

From 3da9fd5f72525da1fb4570cc31e45c957167aa66 Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Fri, 7 Feb 2025 17:46:55 +0100
Subject: [PATCH 7/9] Add `adapt_transformers_to_gaudi` usage

---
 src/flux/hpu_utils.py | 37 ++++++++++++-------------------------
 1 file changed, 12 insertions(+), 25 deletions(-)

diff --git a/src/flux/hpu_utils.py b/src/flux/hpu_utils.py
index 3c444a8d..7699c6e2 100644
--- a/src/flux/hpu_utils.py
+++ b/src/flux/hpu_utils.py
@@ -2,49 +2,36 @@
 
 import torch
 
-
-def load_model_to_hpu(model, model_name=Optional[str]):
+def load_model_to_hpu(model, model_name: Optional[str] = None):
     from habana_frameworks.torch.utils.library_loader import load_habana_module
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
     load_habana_module()
 
     if not torch.hpu.is_available():
         return model
 
-    # Check if model is HFEmbedder (which wraps CLIP or T5)
+    # Adapt transformers models to Gaudi for optimization
+    adapt_transformers_to_gaudi()
+
     if model.__class__.__name__ == "HFEmbedder" and model_name:
         if "t5" in model_name.lower():
-            from transformers.models.t5.modeling_t5 import (
-                T5ForConditionalGeneration
-            )
-            from optimum.habana.transformers.models.t5.modeling_t5 import (
-                gaudi_T5ForConditionalGeneration_forward,
-                gaudi_T5ForConditionalGeneration_prepare_inputs_for_generation
-            )
-
-            # Apply HPU-specific optimizations
-            T5ForConditionalGeneration.forward = gaudi_T5ForConditionalGeneration_forward
-            T5ForConditionalGeneration.prepare_inputs_for_generation = gaudi_T5ForConditionalGeneration_prepare_inputs_for_generation
-
-            # Initialize optimized model
-            model = T5ForConditionalGeneration.from_pretrained(model_name)
-
+            from transformers import T5ForConditionalGeneration
+            model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16)
         elif "clip" in model_name.lower():
             from optimum.habana.transformers.models.clip import GaudiCLIPVisionModel
             model = GaudiCLIPVisionModel.from_pretrained(
                 model_name,
                 use_flash_attention=True,
-                flash_attention_recompute=False
+                flash_attention_recompute=False,
+                torch_dtype=torch.bfloat16
             )
-
-        # Fallback to regular HPU loading
     else:
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
         model = wrap_in_hpu_graph(model)
-        model = model.to(torch.device("hpu"))
+        model = model.to(torch.device("hpu"), dtype=torch.bfloat16)
 
     return model.eval()
 
-
 def get_dtype(device: str) -> torch.dtype:
     """
     Determine the appropriate dtype to use based on the device.
@@ -56,5 +43,5 @@ def get_dtype(device: str) -> torch.dtype:
         torch.dtype: Data type (torch.float32 or torch.bfloat16).
     """
     if "hpu" in device:
-        return torch.float32
+        return torch.bfloat16
     return torch.bfloat16

From 2f53dfc11d37c4e89493d0fc0fd4b88eecbfdb6d Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Wed, 12 Feb 2025 14:16:02 +0100
Subject: [PATCH 8/9] Add timing measurements for HPU vs CPU using time.perf_counter()

---
 demo_hpu.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/demo_hpu.py b/demo_hpu.py
index 7b2b4961..d080eae8 100644
--- a/demo_hpu.py
+++ b/demo_hpu.py
@@ -159,6 +159,11 @@ def main(
         if opts.seed is None:
             opts.seed = rng.seed()
         print(f"Generating '{opts.prompt}' with seed {opts.seed}")
+
+        # Timing for HPU
+        if device == "hpu":
+            t0_hpu = time.perf_counter()
+
         t0 = time.perf_counter()
 
         if init_image is not None:
@@ -203,6 +208,10 @@ def main(
 
         t1 = time.perf_counter()
 
+        if device == "hpu":
+            t1_hpu = time.perf_counter()
+            st.write(f"HPU execution time: {t1_hpu - t0_hpu:.1f}s")
+
         fn = output_name.format(idx=idx)
         print(f"Done in {t1 - t0:.1f}s.")
         x = x.clamp(-1, 1)

From 2d49c283ead4d1729a822c27e111b2498ec94e69 Mon Sep 17 00:00:00 2001
From: PiotrBLL
Date: Wed, 12 Feb 2025 14:22:32 +0100
Subject: [PATCH 9/9] Add usage of CLIPVisionModel

---
 src/flux/hpu_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/flux/hpu_utils.py b/src/flux/hpu_utils.py
index 7699c6e2..3ab9ecbb 100644
--- a/src/flux/hpu_utils.py
+++ b/src/flux/hpu_utils.py
@@ -2,6 +2,7 @@
 
 import torch
 
+
 def load_model_to_hpu(model, model_name: Optional[str] = None):
     from habana_frameworks.torch.utils.library_loader import load_habana_module
     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -18,8 +19,8 @@ def load_model_to_hpu(model, model_name: Optional[str] = None):
             from transformers import T5ForConditionalGeneration
             model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16)
         elif "clip" in model_name.lower():
-            from optimum.habana.transformers.models.clip import GaudiCLIPVisionModel
-            model = GaudiCLIPVisionModel.from_pretrained(
+            from transformers.models.clip.modeling_clip import CLIPVisionModel
+            model = CLIPVisionModel.from_pretrained(
                 model_name,
                 use_flash_attention=True,
                 flash_attention_recompute=False,
@@ -32,6 +33,7 @@ def load_model_to_hpu(model, model_name: Optional[str] = None):
 
     return model.eval()
 
+
 def get_dtype(device: str) -> torch.dtype:
     """
     Determine the appropriate dtype to use based on the device.