diff --git a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py index 49a79b9bf..6b1997221 100644 --- a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py +++ b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py @@ -291,7 +291,7 @@ def get_noise_prediction( else: guidance = torch.tensor( [guidance_embedding_scale], device=self.device_torch) - guidance = guidance.expand(latent_model_input.shape[0]) + guidance = guidance.expand(bs) else: guidance = None @@ -416,4 +416,4 @@ def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchD return latents.detach() def get_base_model_version(self): - return "flux.1_kontext" \ No newline at end of file + return "flux.1_kontext"