ENH: Better DoRA check in mixed adapter batch inference (#2089)
This is a bit of an edge case, but I noticed this while working on
something else.

PEFT allows mixed adapter batch inference, i.e. at prediction time, different samples in the same batch can use different adapters by passing the adapter_names argument. However, this is not supported for DoRA (yet), so there is a check that raises an error if DoRA is used.

Previously, this check inspected all adapters for DoRA, even those not actually used in adapter_names. This was unnecessarily strict; with this PR, only the adapters that are actually being used are checked.
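
For illustration (not part of this commit), here is a minimal, self-contained sketch of mixed adapter batch inference on a toy model. The MLP class and the adapter name "other" are placeholders made up for this example; they roughly mirror the new test added below.

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    # Tiny stand-in model; the real tests use a similar helper defined in test_custom_models.py.
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, X):
        return self.lin1(self.lin0(X))

base_model = MLP().eval()
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["lin0"], init_lora_weights=False))
peft_model.add_adapter(adapter_name="other", peft_config=LoraConfig(target_modules=["lin0"], init_lora_weights=False))
peft_model.eval()

X = torch.arange(90).view(-1, 10).float()  # batch of 9 samples
# One adapter name per sample; "__base__" routes a sample through the base model without any adapter.
out = peft_model.forward(X=X, adapter_names=["default"] * 3 + ["other"] * 3 + ["__base__"] * 3)

If any adapter listed in adapter_names used DoRA, this call would raise a ValueError; after this PR, DoRA adapters that are not listed no longer trigger the error.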
BenjaminBossan authored Sep 24, 2024
1 parent f4cf170 commit 8f39708
Showing 2 changed files with 23 additions and 2 deletions.
src/peft/tuners/lora/layer.py: 3 additions & 1 deletion
@@ -344,7 +344,9 @@ def _check_forward_args(self, x, *args, **kwargs):
             msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
             raise ValueError(msg)
 
-        unique_adapters = set(self.active_adapters)
+        # DoRA is not supported (yet), check that it's not being used. Don't check "__base__", as this is the
+        # placeholder for the base model.
+        unique_adapters = {name for name in adapter_names if name != "__base__"}
         for adapter_name in unique_adapters:
             if self.use_dora.get(adapter_name, False):
                 msg = "Cannot pass `adapter_names` when DoRA is enabled."
tests/test_custom_models.py: 20 additions & 1 deletion
@@ -3485,7 +3485,7 @@ def test_mixed_adapter_batches_lora_merged_raises(self, mlp_lora):
             mlp_lora.forward(**inputs)
 
     def test_mixed_adapter_batches_lora_with_dora_raises(self):
-        # When there are Dora adapters, passing adapter names should raise an error
+        # When there are DoRA adapters, passing adapter names should raise an error
         torch.manual_seed(0)
         inputs = {
             "X": torch.arange(90).view(-1, 10).to(self.torch_device),
@@ -3499,6 +3499,25 @@ def test_mixed_adapter_batches_lora_with_dora_raises(self):
         with pytest.raises(ValueError, match=msg):
             peft_model.forward(**inputs)
 
+    def test_mixed_adapter_batches_lora_with_dora_but_dora_not_included_works(self):
+        # When there are DoRA adapters, passing adapter names should raise an error, see previous test. However, when
+        # the adapter that uses DoRA is not included in adapter_names, it's actually fine.
+        torch.manual_seed(0)
+        base_model = MLP().to(self.torch_device).eval()
+        config_dora = LoraConfig(target_modules=["lin0"], init_lora_weights=False, use_dora=True)
+        peft_model = get_peft_model(base_model, config_dora)
+        config_no_dora = LoraConfig(target_modules=["lin0"], init_lora_weights=False, use_dora=False)
+        peft_model.add_adapter(adapter_name="other", peft_config=config_no_dora)
+        peft_model.eval()
+
+        # The "default" adapter uses DoRA but "other" is not using it, so using "other" is fine. Also, "__base__" is
+        # fine since it uses the base model and thus DoRA is not involved either.
+        inputs = {
+            "X": torch.arange(90).view(-1, 10).to(self.torch_device),
+            "adapter_names": ["other"] * 4 + ["__base__"] * 5,
+        }
+        peft_model.forward(**inputs)
+
     @require_non_cpu
     def test_mixed_adapter_batches_lora_opt_timing(self):
         # Use a more realistic model (opt-125m) and do a simple runtime check to ensure that mixed adapter batches
