From 618d2063c09031ff82d904f2c8567ee63671d4e7 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 23 Dec 2024 13:27:17 +0530
Subject: [PATCH] fixes to tests

---
 tests/lora/test_lora_layers_flux.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 9e7dd74a86f5..b22fbaaed69b 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -330,7 +330,8 @@ def test_lora_parameter_expanded_shapes(self):
         }
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
 
         lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -339,6 +340,7 @@ def test_lora_parameter_expanded_shapes(self):
         self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
         self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
 
+        # Testing opposite direction where the LoRA params are zero-padded.
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
@@ -349,15 +351,21 @@ def test_lora_parameter_expanded_shapes(self):
             "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
             "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
         }
-        # We should error out because lora input features is less than original. We only
-        # support expanding the module, not shrinking it
-        with self.assertRaises(NotImplementedError):
+        with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(lora_state_dict, "adapter-1")
 
-    def test_lora_expanding_shape_with_normal_lora_raises_error(self):
-        # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
-        # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
-        # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
+
+    def test_normal_lora_with_expanded_lora_raises_error(self):
+        # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
+        # load shape expanded LoRA (such as Control LoRA).
         components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
 
         # Change the transformer config to mimic a real use case.