[PEFT] Add warning for missing key in LoRA adapter (#34068)
Until now, loading a LoRA adapter only produced a warning when the checkpoint contained unexpected keys. Now a warning is also emitted when keys expected by the model are missing from the checkpoint.

This change is consistent with
huggingface/peft#2118 in PEFT and the planned PR
huggingface/diffusers#9622 in diffusers.

Apart from this, the warning message for unexpected keys was slightly
reworded for consistency (it should be more readable now). Besides a test
for the missing-keys warning, a test for the unexpected-keys warning was
also added, since none existed before.
BenjaminBossan authored Oct 24, 2024
1 parent fe35073 commit d9989e0
Showing 2 changed files with 96 additions and 6 deletions.
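To see the user-facing effect of this change, here is a minimal sketch (not part of the commit; the base model and adapter repo IDs, and the adapter_model.bin filename, are illustrative placeholders, and LoraConfig defaults are assumed to match the adapter): dropping a LoRA weight from an adapter state dict and loading it now logs a warning instead of staying silent.

import torch
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # illustrative base model

# Download an adapter checkpoint and simulate an incomplete one by dropping a key.
state_dict = torch.load(hf_hub_download("ybelkada/opt-350m-lora", "adapter_model.bin"))
state_dict.pop(next(iter(state_dict)))

# With this commit, the following call logs a warning along the lines of:
#   "Loading adapter weights from state_dict led to missing keys in the model: ..."
model.load_adapter(adapter_state_dict=state_dict, peft_config=LoraConfig())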
24 changes: 20 additions & 4 deletions src/transformers/integrations/peft.py
@@ -235,13 +235,29 @@ def load_adapter(
         )
 
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            err_msg = ""
+            origin_name = peft_model_id if peft_model_id is not None else "state_dict"
+            # Check for unexpected keys.
             if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
-                logger.warning(
-                    f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
-                    f" {incompatible_keys.unexpected_keys}. "
+                err_msg = (
+                    f"Loading adapter weights from {origin_name} led to unexpected keys not found in the model: "
+                    f"{', '.join(incompatible_keys.unexpected_keys)}. "
                 )
+
+            # Check for missing keys.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                # Filter missing keys specific to the current adapter, as missing base model keys are expected.
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    err_msg += (
+                        f"Loading adapter weights from {origin_name} led to missing keys in the model: "
+                        f"{', '.join(lora_missing_keys)}"
+                    )
+
+            if err_msg:
+                logger.warning(err_msg)
 
         # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
         if (
             (getattr(self, "hf_device_map", None) is not None)
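As a side note on the filter above: when an adapter state dict is loaded, all base model weights are reported as missing by PEFT because they are never part of an adapter checkpoint, so only keys belonging to the adapter itself are worth surfacing. A standalone sketch with made-up key names (hypothetical, not taken from the diff):

missing_keys = [
    "model.decoder.layers.0.self_attn.v_proj.base_layer.weight",      # base weight: expected to be "missing"
    "model.decoder.layers.0.self_attn.v_proj.lora_A.default.weight",  # adapter weight: genuinely missing
]
adapter_name = "default"
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
print(lora_missing_keys)  # only the lora_A key for the "default" adapter is reported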
78 changes: 76 additions & 2 deletions tests/peft_integration/test_peft_integration.py
@@ -20,8 +20,9 @@
 from huggingface_hub import hf_hub_download
 from packaging import version
 
-from transformers import AutoModelForCausalLM, OPTForCausalLM
+from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
 from transformers.testing_utils import (
+    CaptureLogger,
     require_bitsandbytes,
     require_peft,
     require_torch,
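For readers unfamiliar with it, CaptureLogger (imported above from transformers.testing_utils) attaches a handler to the given logger and exposes what was logged inside the context as cl.out, which the tests below assert against. A minimal, hedged sketch with a placeholder message:

from transformers import logging
from transformers.testing_utils import CaptureLogger

logger = logging.get_logger("transformers.integrations.peft")
with CaptureLogger(logger) as cl:
    logger.warning("example warning")  # stands in for the adapter-loading warnings
assert "example warning" in cl.out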
@@ -72,9 +73,15 @@ def test_peft_from_pretrained(self):
         This checks if we pass a remote folder that contains an adapter config and adapter weights, it
         should correctly load a model that has adapters injected on it.
         """
+        logger = logging.get_logger("transformers.integrations.peft")
+
         for model_id in self.peft_test_model_ids:
             for transformers_class in self.transformers_test_model_classes:
-                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with CaptureLogger(logger) as cl:
+                    peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                # ensure that under normal circumstances, there are no warnings about keys
+                self.assertNotIn("unexpected keys", cl.out)
+                self.assertNotIn("missing keys", cl.out)
 
                 self.assertTrue(self._check_lora_correctly_converted(peft_model))
                 self.assertTrue(peft_model._hf_peft_config_loaded)
@@ -548,3 +555,70 @@ def test_peft_from_pretrained_hub_kwargs(self):
 
         model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
         self.assertTrue(self._check_lora_correctly_converted(model))
+
+    def test_peft_from_pretrained_unexpected_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with unexpected keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # add unexpected key
+                dummy_state_dict["foobar"] = next(iter(dummy_state_dict.values()))
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                    )
+
+                msg = "Loading adapter weights from state_dict led to unexpected keys not found in the model: foobar"
+                self.assertIn(msg, cl.out)
+
+    def test_peft_from_pretrained_missing_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with missing keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # remove a key so that we have missing keys
+                key = next(iter(dummy_state_dict.keys()))
+                del dummy_state_dict[key]
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=False,
+                        adapter_name="other",
+                    )
+
+                # Here we need to adjust the key name a bit to account for PEFT-specific naming.
+                # 1. Remove PEFT-specific prefix
+                # If merged after dropping Python 3.8, we can use: key = key.removeprefix(peft_prefix)
+                peft_prefix = "base_model.model."
+                key = key[len(peft_prefix) :]
+                # 2. Insert adapter name
+                prefix, _, suffix = key.rpartition(".")
+                key = f"{prefix}.other.{suffix}"
+
+                msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
+                self.assertIn(msg, cl.out)
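To make the key-name adjustment at the end of the missing-keys test concrete, here is a worked example with a hypothetical checkpoint key (the exact key depends on the adapter being loaded):

# Key as stored in a PEFT adapter state dict (hypothetical example).
key = "base_model.model.model.decoder.layers.0.self_attn.v_proj.lora_A.weight"

# 1. Strip the PEFT-specific prefix so the key matches the transformers module tree.
peft_prefix = "base_model.model."
key = key[len(peft_prefix):]  # -> "model.decoder.layers.0.self_attn.v_proj.lora_A.weight"

# 2. Insert the adapter name ("other" in the test), since PEFT stores each adapter's
#    weights under a per-adapter submodule.
prefix, _, suffix = key.rpartition(".")
key = f"{prefix}.other.{suffix}"  # -> "model.decoder.layers.0.self_attn.v_proj.lora_A.other.weight"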
