diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index a758d545fa4a..ca28422f9ca0 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -260,10 +260,10 @@ def rewire_prompt(self, prompt, device): text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) all_text.append(text) - inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device) + inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device) - self.text_encoder.to(device) generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) + generated_ids.to(device) generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] output_text = self.text_processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False